Merge 'cql_server::connection: Process rebounce message in case of multiple shard migrations' from Sergey Zolotukhin

During a query execution, the query can be re-bounced to another shard if the requested data is located there. Previous implementation assumed that the shard cannot be changed after first re-bounce, however with the introduction of Tablets, data could be migrated to another shard after the query was already re-bounced, causing a failure of the query execution. To avoid this issue, the query is re-bounced as needed until it is executed on the correct shard.

Fixes #15465

Closes scylladb/scylladb#20493

* github.com:scylladb/scylladb:
  cql_server: Add a test for multiple query msg rebounces.
  cql_server::connection: process: rebounce msg if needed
  cql_server::connection: process: co-routinize connection::process_on_shard
  cql_server: connection: process: fixup indentation
  cql_server: connection: process_on_shard: drop permit parameter
  transport: server: pass bounce_to_shard as foreign shared ptr
  cql_server: connection: process: add template concept for process_fn
  cql_server: move process_fn_return_type to class definition
This commit is contained in:
Avi Kivity
2024-09-19 17:27:54 +03:00
4 changed files with 219 additions and 54 deletions

View File

@@ -316,6 +316,62 @@ modification_statement::execute_without_condition(query_processor& qp, service::
});
}
namespace {
future<::shared_ptr<cql_transport::messages::result_message>>
process_forced_rebounce(unsigned shard, query_processor& qp, const query_options& options) {
static int64_t counter = {0};
static logging::logger logger("modification_statement");
if (counter <= 0) {
const auto counter_opt = utils::get_local_injector().inject_parameter<decltype(counter)>("forced_bounce_to_shard_counter");
decltype(counter) counter_value = 0;
if (!counter_opt) {
logger.warn("forced_bounce_to_shard_counter is not set. Using default value 1.");
} else {
try {
counter_value = boost::lexical_cast<decltype(counter_value)>(*counter_opt);
} catch (const boost::bad_lexical_cast& e) {
logger.warn("Incorrect forced_bounce_to_shard_counter value: [{}]. Using default value 1.", *counter_opt);
}
}
if (counter_value <= 0) {
counter_value = 1;
}
counter = counter_value;
}
const auto prev_counter_value = counter;
if (prev_counter_value <= 1) {
logger.info("Disabling forced_bounce_to_shard_counter.");
co_await utils::error_injection_type::disable_on_all("forced_bounce_to_shard_counter");
counter = 0;
} else {
--counter;
}
// While counter > 1 select a different shard to re-bounce to.
// On the last iteration, re-bounce to the correct shard.
if (counter != 0) {
const auto shard_num = smp::count;
assert(shard_num > 0);
const auto local_shard = this_shard_id();
auto target_shard = local_shard + 1;
if (target_shard == shard) {
++target_shard;
}
if (target_shard > shard_num - 1) {
target_shard = 0;
}
shard = target_shard;
}
logger.info("Applying forced_bounce_to_shard_counter, re-bouncing to shard {}.", shard);
co_return co_await make_ready_future<shared_ptr<cql_transport::messages::result_message>>(
qp.bounce_to_shard(shard, std::move(const_cast<cql3::query_options&>(options).take_cached_pk_function_calls())));
}
} // namespace
future<::shared_ptr<cql_transport::messages::result_message>>
modification_statement::execute_with_condition(query_processor& qp, service::query_state& qs, const query_options& options) const {
@@ -349,6 +405,11 @@ modification_statement::execute_with_condition(query_processor& qp, service::que
auto token = request->key()[0].start()->value().as_decorated_key().token();
auto shard = service::storage_proxy::cas_shard(*s, token);
if (utils::get_local_injector().is_enabled("forced_bounce_to_shard_counter")) {
return process_forced_rebounce(shard, qp, options);
}
if (shard != this_shard_id()) {
return make_ready_future<shared_ptr<cql_transport::messages::result_message>>(
qp.bounce_to_shard(shard, std::move(const_cast<cql3::query_options&>(options).take_cached_pk_function_calls()))

View File

@@ -0,0 +1,65 @@
#
# Copyright (C) 2024-present ScyllaDB
#
# SPDX-License-Identifier: AGPL-3.0-or-later
#
import asyncio
import pytest
import logging
import time
from test.pylib.internal_types import IPAddress
from test.pylib.manager_client import ManagerClient
from test.pylib.rest_client import inject_error
logger = logging.getLogger(__name__)
@pytest.mark.asyncio
async def test_query_rebounce(manager: ManagerClient):
"""
Issue https://github.com/scylladb/scylladb/issues/15465.
Test emulating several LWT(Lightweight Transaction) query rebounces. Currently, the code
that processes queries does not expect that a query may be rebounced more than once.
It was impossible with the VNodes, but with intruduction of the Tablets, data can be moved
between shards by the balancer thus a query can be rebounced to different shards multiple times.
1) Create a keyspace and a table.
2) Insert some data.
3) Inject an error to force a rebounce 2 times.
4) Update the data with a LWT query. The update will fail on this step with the current implementation.
5) Check the result of update.
"""
cmdline = [
'--logger-log-level', 'raft=trace',
]
servers = await manager.servers_add(1, cmdline=cmdline)
servers = await manager.running_servers()
cql = manager.get_cql()
await cql.run_async("create keyspace ks with replication = {'class': 'SimpleStrategy', 'replication_factor': 1}"
"and tablets = {'enabled': false};")
await cql.run_async("create table ks.lwt (a int, b int, primary key(a));")
await cql.run_async("insert into ks.lwt (a,b ) values (1, 10);")
await cql.run_async("insert into ks.lwt (a,b ) values (2, 20);")
errs = [manager.api.enable_injection(s.ip_addr, "forced_bounce_to_shard_counter", one_shot=False,
parameters={'value': '2'})
for s in servers]
await asyncio.gather(*errs)
await cql.run_async("update ks.lwt set b = 11 where a = 1 if b = 10;")
rows = await cql.run_async("select b from ks.lwt where a = 1;")
assert rows[0].b == 11

View File

@@ -962,59 +962,69 @@ std::unique_ptr<cql_server::response>
make_result(int16_t stream, messages::result_message& msg, const tracing::trace_state_ptr& tr_state,
cql_protocol_version_type version, bool skip_metadata = false);
template<typename Process>
future<cql_server::result_with_foreign_response_ptr>
cql_server::connection::process_on_shard(::shared_ptr<messages::result_message::bounce_to_shard> bounce_msg, uint16_t stream, fragmented_temporary_buffer::istream is,
service::client_state& cs, service_permit permit, tracing::trace_state_ptr trace_state, cql3::dialect dialect, Process process_fn) {
return _server.container().invoke_on(*bounce_msg->move_to_shard(), _server._config.bounce_request_smp_service_group,
[this, is = std::move(is), cs = cs.move_to_other_shard(), stream, permit = std::move(permit), process_fn,
gt = tracing::global_trace_state_ptr(std::move(trace_state)),
cached_vals = std::move(bounce_msg->take_cached_pk_function_calls()), dialect] (cql_server& server) {
service::client_state client_state = cs.get();
return do_with(bytes_ostream(), std::move(client_state), std::move(cached_vals),
[this, &server, is = std::move(is), stream, process_fn,
trace_state = tracing::trace_state_ptr(gt), dialect] (bytes_ostream& linearization_buffer,
service::client_state& client_state,
cql3::computed_function_values& cached_vals) mutable {
request_reader in(is, linearization_buffer);
return process_fn(client_state, server._query_processor, in, stream, _version,
/* FIXME */empty_service_permit(), std::move(trace_state), false, std::move(cached_vals), dialect).then([] (auto msg) {
// result here has to be foreign ptr
return std::get<cql_server::result_with_foreign_response_ptr>(std::move(msg));
});
});
template <typename Process>
requires std::is_invocable_r_v<future<cql_server::process_fn_return_type>,
Process,
service::client_state&,
distributed<cql3::query_processor>&,
request_reader,
uint16_t,
cql_protocol_version_type,
service_permit,
tracing::trace_state_ptr,
bool,
cql3::computed_function_values,
cql3::dialect>
future<cql_server::process_fn_return_type>
cql_server::connection::process_on_shard(shard_id shard, uint16_t stream, fragmented_temporary_buffer::istream is, service::client_state& cs,
tracing::trace_state_ptr trace_state, cql3::dialect dialect, cql3::computed_function_values&& cached_vals, Process process_fn) {
auto sg = _server._config.bounce_request_smp_service_group;
auto gcs = cs.move_to_other_shard();
auto gt = tracing::global_trace_state_ptr(std::move(trace_state));
co_return co_await _server.container().invoke_on(shard, sg, [&, stream, dialect] (cql_server& server) -> future<process_fn_return_type> {
bytes_ostream linearization_buffer;
request_reader in(is, linearization_buffer);
auto client_state = gcs.get();
auto trace_state = gt.get();
co_return co_await process_fn(client_state, server._query_processor, in, stream, _version,
/* FIXME */empty_service_permit(), std::move(trace_state), false, cached_vals, dialect);
});
}
using process_fn_return_type = std::variant<
cql_server::result_with_foreign_response_ptr,
::shared_ptr<messages::result_message::bounce_to_shard>>;
static inline cql_server::result_with_foreign_response_ptr convert_error_message_to_coordinator_result(messages::result_message* msg) {
return std::move(*dynamic_cast<messages::result_message::exception*>(msg)).get_exception();
}
template<typename Process>
template <typename Process>
requires std::is_invocable_r_v<future<cql_server::process_fn_return_type>,
Process,
service::client_state&,
distributed<cql3::query_processor>&,
request_reader,
uint16_t,
cql_protocol_version_type,
service_permit,
tracing::trace_state_ptr,
bool,
cql3::computed_function_values,
cql3::dialect>
future<cql_server::result_with_foreign_response_ptr>
cql_server::connection::process(uint16_t stream, request_reader in, service::client_state& client_state, service_permit permit,
tracing::trace_state_ptr trace_state, Process process_fn) {
fragmented_temporary_buffer::istream is = in.get_stream();
auto dialect = get_dialect();
return process_fn(client_state, _server._query_processor, in, stream,
_version, permit, trace_state, true, {}, dialect)
.then([stream, &client_state, this, is, permit, process_fn, trace_state, dialect]
(process_fn_return_type msg) mutable {
auto* bounce_msg = std::get_if<shared_ptr<messages::result_message::bounce_to_shard>>(&msg);
if (bounce_msg) {
return process_on_shard(*bounce_msg, stream, is, client_state, std::move(permit), trace_state, dialect, process_fn);
}
auto ptr = std::get<cql_server::result_with_foreign_response_ptr>(std::move(msg));
return make_ready_future<cql_server::result_with_foreign_response_ptr>(std::move(ptr));
});
auto msg = co_await process_fn(client_state, _server._query_processor, in, stream,
_version, permit, trace_state, true, {}, dialect);
while (auto* bounce_msg = std::get_if<result_with_bounce_to_shard>(&msg)) {
auto shard = (*bounce_msg)->move_to_shard().value();
auto&& cached_vals = (*bounce_msg)->take_cached_pk_function_calls();
msg = co_await process_on_shard(shard, stream, is, client_state, trace_state, dialect, std::move(cached_vals), process_fn);
}
co_return std::get<cql_server::result_with_foreign_response_ptr>(std::move(msg));
}
static future<process_fn_return_type>
static future<cql_server::process_fn_return_type>
process_query_internal(service::client_state& client_state, distributed<cql3::query_processor>& qp, request_reader in,
uint16_t stream, cql_protocol_version_type version,
service_permit permit, tracing::trace_state_ptr trace_state, bool init_trace, cql3::computed_function_values cached_pk_fn_calls,
@@ -1041,12 +1051,12 @@ process_query_internal(service::client_state& client_state, distributed<cql3::qu
return qp.local().execute_direct_without_checking_exception_message(query, query_state, dialect, options).then([q_state = std::move(q_state), stream, skip_metadata, version] (auto msg) {
if (msg->move_to_shard()) {
return process_fn_return_type(dynamic_pointer_cast<messages::result_message::bounce_to_shard>(msg));
return cql_server::process_fn_return_type(make_foreign(dynamic_pointer_cast<messages::result_message::bounce_to_shard>(msg)));
} else if (msg->is_exception()) {
return process_fn_return_type(convert_error_message_to_coordinator_result(msg.get()));
return cql_server::process_fn_return_type(convert_error_message_to_coordinator_result(msg.get()));
} else {
tracing::trace(q_state->query_state.get_trace_state(), "Done processing - preparing a result");
return process_fn_return_type(make_foreign(make_result(stream, *msg, q_state->query_state.get_trace_state(), version, skip_metadata)));
return cql_server::process_fn_return_type(make_foreign(make_result(stream, *msg, q_state->query_state.get_trace_state(), version, skip_metadata)));
}
});
}
@@ -1078,7 +1088,7 @@ future<std::unique_ptr<cql_server::response>> cql_server::connection::process_pr
});
}
static future<process_fn_return_type>
static future<cql_server::process_fn_return_type>
process_execute_internal(service::client_state& client_state, distributed<cql3::query_processor>& qp, request_reader in,
uint16_t stream, cql_protocol_version_type version,
service_permit permit, tracing::trace_state_ptr trace_state, bool init_trace, cql3::computed_function_values cached_pk_fn_calls,
@@ -1139,12 +1149,12 @@ process_execute_internal(service::client_state& client_state, distributed<cql3::
return qp.local().execute_prepared_without_checking_exception_message(query_state, std::move(stmt), options, std::move(prepared), std::move(cache_key), needs_authorization)
.then([trace_state = query_state.get_trace_state(), skip_metadata, q_state = std::move(q_state), stream, version] (auto msg) {
if (msg->move_to_shard()) {
return process_fn_return_type(dynamic_pointer_cast<messages::result_message::bounce_to_shard>(msg));
return cql_server::process_fn_return_type(make_foreign(dynamic_pointer_cast<messages::result_message::bounce_to_shard>(msg)));
} else if (msg->is_exception()) {
return process_fn_return_type(convert_error_message_to_coordinator_result(msg.get()));
return cql_server::process_fn_return_type(convert_error_message_to_coordinator_result(msg.get()));
} else {
tracing::trace(q_state->query_state.get_trace_state(), "Done processing - preparing a result");
return process_fn_return_type(make_foreign(make_result(stream, *msg, q_state->query_state.get_trace_state(), version, skip_metadata)));
return cql_server::process_fn_return_type(make_foreign(make_result(stream, *msg, q_state->query_state.get_trace_state(), version, skip_metadata)));
}
});
}
@@ -1154,7 +1164,7 @@ future<cql_server::result_with_foreign_response_ptr> cql_server::connection::pro
return process(stream, in, client_state, std::move(permit), std::move(trace_state), process_execute_internal);
}
static future<process_fn_return_type>
static future<cql_server::process_fn_return_type>
process_batch_internal(service::client_state& client_state, distributed<cql3::query_processor>& qp, request_reader in,
uint16_t stream, cql_protocol_version_type version,
service_permit permit, tracing::trace_state_ptr trace_state, bool init_trace, cql3::computed_function_values cached_pk_fn_calls, cql3::dialect dialect) {
@@ -1260,12 +1270,12 @@ process_batch_internal(service::client_state& client_state, distributed<cql3::qu
return qp.local().execute_batch_without_checking_exception_message(batch, query_state, options, std::move(pending_authorization_entries))
.then([stream, batch, q_state = std::move(q_state), trace_state = query_state.get_trace_state(), version] (auto msg) {
if (msg->move_to_shard()) {
return process_fn_return_type(dynamic_pointer_cast<messages::result_message::bounce_to_shard>(msg));
return cql_server::process_fn_return_type(make_foreign(dynamic_pointer_cast<messages::result_message::bounce_to_shard>(msg)));
} else if (msg->is_exception()) {
return process_fn_return_type(convert_error_message_to_coordinator_result(msg.get()));
return cql_server::process_fn_return_type(convert_error_message_to_coordinator_result(msg.get()));
} else {
tracing::trace(q_state->query_state.get_trace_state(), "Done processing - preparing a result");
return process_fn_return_type(make_foreign(make_result(stream, *msg, trace_state, version)));
return cql_server::process_fn_return_type(make_foreign(make_result(stream, *msg, trace_state, version)));
}
});
}

View File

@@ -18,6 +18,7 @@
#include "timeout_config.hh"
#include <seastar/core/semaphore.hh>
#include <memory>
#include <type_traits>
#include <boost/intrusive/list.hpp>
#include <seastar/net/tls.hh>
#include <seastar/core/metrics_registration.hh>
@@ -186,6 +187,9 @@ public:
public:
using response = cql_transport::response;
using result_with_foreign_response_ptr = exceptions::coordinator_result<foreign_ptr<std::unique_ptr<cql_server::response>>>;
using result_with_bounce_to_shard = foreign_ptr<seastar::shared_ptr<messages::result_message::bounce_to_shard>>;
using process_fn_return_type = std::variant<result_with_foreign_response_ptr, result_with_bounce_to_shard>;
service::endpoint_lifecycle_subscriber* get_lifecycle_listener() const noexcept;
service::migration_listener* get_migration_listener() const noexcept;
qos::qos_configuration_change_subscriber* get_qos_configuration_listener() const noexcept;
@@ -280,14 +284,39 @@ private:
cql3::dialect get_dialect() const;
// Helper functions to encapsulate bounce_to_shard processing for query, execute and batch verbs
template<typename Process>
template <typename Process>
requires std::is_invocable_r_v<future<cql_server::process_fn_return_type>,
Process,
service::client_state&,
distributed<cql3::query_processor>&,
request_reader,
uint16_t,
cql_protocol_version_type,
service_permit,
tracing::trace_state_ptr,
bool,
cql3::computed_function_values,
cql3::dialect>
future<result_with_foreign_response_ptr>
process(uint16_t stream, request_reader in, service::client_state& client_state, service_permit permit, tracing::trace_state_ptr trace_state,
Process process_fn);
template<typename Process>
future<result_with_foreign_response_ptr>
process_on_shard(::shared_ptr<messages::result_message::bounce_to_shard> bounce_msg, uint16_t stream, fragmented_temporary_buffer::istream is, service::client_state& cs,
service_permit permit, tracing::trace_state_ptr trace_state, cql3::dialect dialect, Process process_fn);
template <typename Process>
requires std::is_invocable_r_v<future<cql_server::process_fn_return_type>,
Process,
service::client_state&,
distributed<cql3::query_processor>&,
request_reader,
uint16_t,
cql_protocol_version_type,
service_permit,
tracing::trace_state_ptr,
bool,
cql3::computed_function_values,
cql3::dialect>
future<process_fn_return_type>
process_on_shard(shard_id shard, uint16_t stream, fragmented_temporary_buffer::istream is, service::client_state& cs,
tracing::trace_state_ptr trace_state, cql3::dialect dialect, cql3::computed_function_values&& cached_vals, Process process_fn);
void write_response(foreign_ptr<std::unique_ptr<cql_server::response>>&& response, service_permit permit = empty_service_permit(), cql_compression compression = cql_compression::none);