diff --git a/cql3/statements/modification_statement.cc b/cql3/statements/modification_statement.cc index 09dd310281..4908723184 100644 --- a/cql3/statements/modification_statement.cc +++ b/cql3/statements/modification_statement.cc @@ -316,6 +316,62 @@ modification_statement::execute_without_condition(query_processor& qp, service:: }); } +namespace { + +future<::shared_ptr> +process_forced_rebounce(unsigned shard, query_processor& qp, const query_options& options) { + static int64_t counter = {0}; + static logging::logger logger("modification_statement"); + if (counter <= 0) { + const auto counter_opt = utils::get_local_injector().inject_parameter("forced_bounce_to_shard_counter"); + decltype(counter) counter_value = 0; + if (!counter_opt) { + logger.warn("forced_bounce_to_shard_counter is not set. Using default value 1."); + } else { + try { + counter_value = boost::lexical_cast(*counter_opt); + } catch (const boost::bad_lexical_cast& e) { + logger.warn("Incorrect forced_bounce_to_shard_counter value: [{}]. Using default value 1.", *counter_opt); + } + } + if (counter_value <= 0) { + counter_value = 1; + } + counter = counter_value; + } + + const auto prev_counter_value = counter; + if (prev_counter_value <= 1) { + logger.info("Disabling forced_bounce_to_shard_counter."); + co_await utils::error_injection_type::disable_on_all("forced_bounce_to_shard_counter"); + counter = 0; + } else { + --counter; + } + + // While counter > 1 select a different shard to re-bounce to. + // On the last iteration, re-bounce to the correct shard. 
+ if (counter != 0) { + const auto shard_num = smp::count; + assert(shard_num > 0); + const auto local_shard = this_shard_id(); + auto target_shard = local_shard + 1; + if (target_shard == shard) { + ++target_shard; + } + if (target_shard > shard_num - 1) { + target_shard = 0; + } + shard = target_shard; + } + + logger.info("Applying forced_bounce_to_shard_counter, re-bouncing to shard {}.", shard); + co_return co_await make_ready_future>( + qp.bounce_to_shard(shard, std::move(const_cast(options).take_cached_pk_function_calls()))); +} + +} // namespace + future<::shared_ptr> modification_statement::execute_with_condition(query_processor& qp, service::query_state& qs, const query_options& options) const { @@ -349,6 +405,11 @@ modification_statement::execute_with_condition(query_processor& qp, service::que auto token = request->key()[0].start()->value().as_decorated_key().token(); auto shard = service::storage_proxy::cas_shard(*s, token); + + if (utils::get_local_injector().is_enabled("forced_bounce_to_shard_counter")) { + return process_forced_rebounce(shard, qp, options); + } + if (shard != this_shard_id()) { return make_ready_future>( qp.bounce_to_shard(shard, std::move(const_cast(options).take_cached_pk_function_calls())) diff --git a/test/topology_custom/test_query_rebounce.py b/test/topology_custom/test_query_rebounce.py new file mode 100644 index 0000000000..fc4d6b4d30 --- /dev/null +++ b/test/topology_custom/test_query_rebounce.py @@ -0,0 +1,65 @@ +# +# Copyright (C) 2024-present ScyllaDB +# +# SPDX-License-Identifier: AGPL-3.0-or-later +# + +import asyncio +import pytest +import logging +import time + +from test.pylib.internal_types import IPAddress +from test.pylib.manager_client import ManagerClient +from test.pylib.rest_client import inject_error + + +logger = logging.getLogger(__name__) + + +@pytest.mark.asyncio +async def test_query_rebounce(manager: ManagerClient): + """ + Issue https://github.com/scylladb/scylladb/issues/15465. 
+ + Test emulating several LWT(Lightweight Transaction) query rebounces. Currently, the code + that processes queries does not expect that a query may be rebounced more than once. + It was impossible with the VNodes, but with introduction of the Tablets, data can be moved + between shards by the balancer thus a query can be rebounced to different shards multiple times. + + 1) Create a keyspace and a table. + 2) Insert some data. + 3) Inject an error to force a rebounce 2 times. + 4) Update the data with a LWT query. The update will fail on this step with the current implementation. + 5) Check the result of update. + + """ + + cmdline = [ + '--logger-log-level', 'raft=trace', + ] + + servers = await manager.servers_add(1, cmdline=cmdline) + + servers = await manager.running_servers() + cql = manager.get_cql() + + await cql.run_async("create keyspace ks with replication = {'class': 'SimpleStrategy', 'replication_factor': 1}" + "and tablets = {'enabled': false};") + + await cql.run_async("create table ks.lwt (a int, b int, primary key(a));") + + await cql.run_async("insert into ks.lwt (a,b ) values (1, 10);") + await cql.run_async("insert into ks.lwt (a,b ) values (2, 20);") + + errs = [manager.api.enable_injection(s.ip_addr, "forced_bounce_to_shard_counter", one_shot=False, + parameters={'value': '2'}) + for s in servers] + + await asyncio.gather(*errs) + + await cql.run_async("update ks.lwt set b = 11 where a = 1 if b = 10;") + + rows = await cql.run_async("select b from ks.lwt where a = 1;") + + assert rows[0].b == 11 diff --git a/transport/server.cc b/transport/server.cc index 24bef72be2..af1f3bbe0b 100644 --- a/transport/server.cc +++ b/transport/server.cc @@ -962,59 +962,69 @@ std::unique_ptr make_result(int16_t stream, messages::result_message& msg, const tracing::trace_state_ptr& tr_state, cql_protocol_version_type version, bool skip_metadata = false); -template -future -cql_server::connection::process_on_shard(::shared_ptr bounce_msg, uint16_t stream, 
fragmented_temporary_buffer::istream is, - service::client_state& cs, service_permit permit, tracing::trace_state_ptr trace_state, cql3::dialect dialect, Process process_fn) { - return _server.container().invoke_on(*bounce_msg->move_to_shard(), _server._config.bounce_request_smp_service_group, - [this, is = std::move(is), cs = cs.move_to_other_shard(), stream, permit = std::move(permit), process_fn, - gt = tracing::global_trace_state_ptr(std::move(trace_state)), - cached_vals = std::move(bounce_msg->take_cached_pk_function_calls()), dialect] (cql_server& server) { - service::client_state client_state = cs.get(); - return do_with(bytes_ostream(), std::move(client_state), std::move(cached_vals), - [this, &server, is = std::move(is), stream, process_fn, - trace_state = tracing::trace_state_ptr(gt), dialect] (bytes_ostream& linearization_buffer, - service::client_state& client_state, - cql3::computed_function_values& cached_vals) mutable { - request_reader in(is, linearization_buffer); - return process_fn(client_state, server._query_processor, in, stream, _version, - /* FIXME */empty_service_permit(), std::move(trace_state), false, std::move(cached_vals), dialect).then([] (auto msg) { - // result here has to be foreign ptr - return std::get(std::move(msg)); - }); - }); +template + requires std::is_invocable_r_v, + Process, + service::client_state&, + distributed&, + request_reader, + uint16_t, + cql_protocol_version_type, + service_permit, + tracing::trace_state_ptr, + bool, + cql3::computed_function_values, + cql3::dialect> +future +cql_server::connection::process_on_shard(shard_id shard, uint16_t stream, fragmented_temporary_buffer::istream is, service::client_state& cs, + tracing::trace_state_ptr trace_state, cql3::dialect dialect, cql3::computed_function_values&& cached_vals, Process process_fn) { + auto sg = _server._config.bounce_request_smp_service_group; + auto gcs = cs.move_to_other_shard(); + auto gt = tracing::global_trace_state_ptr(std::move(trace_state)); 
+ co_return co_await _server.container().invoke_on(shard, sg, [&, stream, dialect] (cql_server& server) -> future { + bytes_ostream linearization_buffer; + request_reader in(is, linearization_buffer); + auto client_state = gcs.get(); + auto trace_state = gt.get(); + co_return co_await process_fn(client_state, server._query_processor, in, stream, _version, + /* FIXME */empty_service_permit(), std::move(trace_state), false, cached_vals, dialect); }); } -using process_fn_return_type = std::variant< - cql_server::result_with_foreign_response_ptr, - ::shared_ptr>; - static inline cql_server::result_with_foreign_response_ptr convert_error_message_to_coordinator_result(messages::result_message* msg) { return std::move(*dynamic_cast(msg)).get_exception(); } -template +template + requires std::is_invocable_r_v, + Process, + service::client_state&, + distributed&, + request_reader, + uint16_t, + cql_protocol_version_type, + service_permit, + tracing::trace_state_ptr, + bool, + cql3::computed_function_values, + cql3::dialect> future cql_server::connection::process(uint16_t stream, request_reader in, service::client_state& client_state, service_permit permit, tracing::trace_state_ptr trace_state, Process process_fn) { fragmented_temporary_buffer::istream is = in.get_stream(); auto dialect = get_dialect(); - return process_fn(client_state, _server._query_processor, in, stream, - _version, permit, trace_state, true, {}, dialect) - .then([stream, &client_state, this, is, permit, process_fn, trace_state, dialect] - (process_fn_return_type msg) mutable { - auto* bounce_msg = std::get_if>(&msg); - if (bounce_msg) { - return process_on_shard(*bounce_msg, stream, is, client_state, std::move(permit), trace_state, dialect, process_fn); - } - auto ptr = std::get(std::move(msg)); - return make_ready_future(std::move(ptr)); - }); + auto msg = co_await process_fn(client_state, _server._query_processor, in, stream, + _version, permit, trace_state, true, {}, dialect); + while (auto* 
bounce_msg = std::get_if(&msg)) { + auto shard = (*bounce_msg)->move_to_shard().value(); + auto&& cached_vals = (*bounce_msg)->take_cached_pk_function_calls(); + msg = co_await process_on_shard(shard, stream, is, client_state, trace_state, dialect, std::move(cached_vals), process_fn); + } + co_return std::get(std::move(msg)); } -static future +static future process_query_internal(service::client_state& client_state, distributed& qp, request_reader in, uint16_t stream, cql_protocol_version_type version, service_permit permit, tracing::trace_state_ptr trace_state, bool init_trace, cql3::computed_function_values cached_pk_fn_calls, @@ -1041,12 +1051,12 @@ process_query_internal(service::client_state& client_state, distributedmove_to_shard()) { - return process_fn_return_type(dynamic_pointer_cast(msg)); + return cql_server::process_fn_return_type(make_foreign(dynamic_pointer_cast(msg))); } else if (msg->is_exception()) { - return process_fn_return_type(convert_error_message_to_coordinator_result(msg.get())); + return cql_server::process_fn_return_type(convert_error_message_to_coordinator_result(msg.get())); } else { tracing::trace(q_state->query_state.get_trace_state(), "Done processing - preparing a result"); - return process_fn_return_type(make_foreign(make_result(stream, *msg, q_state->query_state.get_trace_state(), version, skip_metadata))); + return cql_server::process_fn_return_type(make_foreign(make_result(stream, *msg, q_state->query_state.get_trace_state(), version, skip_metadata))); } }); } @@ -1078,7 +1088,7 @@ future> cql_server::connection::process_pr }); } -static future +static future process_execute_internal(service::client_state& client_state, distributed& qp, request_reader in, uint16_t stream, cql_protocol_version_type version, service_permit permit, tracing::trace_state_ptr trace_state, bool init_trace, cql3::computed_function_values cached_pk_fn_calls, @@ -1139,12 +1149,12 @@ process_execute_internal(service::client_state& client_state, 
distributedmove_to_shard()) { - return process_fn_return_type(dynamic_pointer_cast(msg)); + return cql_server::process_fn_return_type(make_foreign(dynamic_pointer_cast(msg))); } else if (msg->is_exception()) { - return process_fn_return_type(convert_error_message_to_coordinator_result(msg.get())); + return cql_server::process_fn_return_type(convert_error_message_to_coordinator_result(msg.get())); } else { tracing::trace(q_state->query_state.get_trace_state(), "Done processing - preparing a result"); - return process_fn_return_type(make_foreign(make_result(stream, *msg, q_state->query_state.get_trace_state(), version, skip_metadata))); + return cql_server::process_fn_return_type(make_foreign(make_result(stream, *msg, q_state->query_state.get_trace_state(), version, skip_metadata))); } }); } @@ -1154,7 +1164,7 @@ future cql_server::connection::pro return process(stream, in, client_state, std::move(permit), std::move(trace_state), process_execute_internal); } -static future +static future process_batch_internal(service::client_state& client_state, distributed& qp, request_reader in, uint16_t stream, cql_protocol_version_type version, service_permit permit, tracing::trace_state_ptr trace_state, bool init_trace, cql3::computed_function_values cached_pk_fn_calls, cql3::dialect dialect) { @@ -1260,12 +1270,12 @@ process_batch_internal(service::client_state& client_state, distributedmove_to_shard()) { - return process_fn_return_type(dynamic_pointer_cast(msg)); + return cql_server::process_fn_return_type(make_foreign(dynamic_pointer_cast(msg))); } else if (msg->is_exception()) { - return process_fn_return_type(convert_error_message_to_coordinator_result(msg.get())); + return cql_server::process_fn_return_type(convert_error_message_to_coordinator_result(msg.get())); } else { tracing::trace(q_state->query_state.get_trace_state(), "Done processing - preparing a result"); - return process_fn_return_type(make_foreign(make_result(stream, *msg, trace_state, version))); + return 
cql_server::process_fn_return_type(make_foreign(make_result(stream, *msg, trace_state, version))); } }); } diff --git a/transport/server.hh b/transport/server.hh index cd4acaf81f..52cf61b050 100644 --- a/transport/server.hh +++ b/transport/server.hh @@ -18,6 +18,7 @@ #include "timeout_config.hh" #include #include +#include #include #include #include @@ -186,6 +187,9 @@ public: public: using response = cql_transport::response; using result_with_foreign_response_ptr = exceptions::coordinator_result>>; + using result_with_bounce_to_shard = foreign_ptr>; + using process_fn_return_type = std::variant; + service::endpoint_lifecycle_subscriber* get_lifecycle_listener() const noexcept; service::migration_listener* get_migration_listener() const noexcept; qos::qos_configuration_change_subscriber* get_qos_configuration_listener() const noexcept; @@ -280,14 +284,39 @@ private: cql3::dialect get_dialect() const; // Helper functions to encapsulate bounce_to_shard processing for query, execute and batch verbs - template + template + requires std::is_invocable_r_v, + Process, + service::client_state&, + distributed&, + request_reader, + uint16_t, + cql_protocol_version_type, + service_permit, + tracing::trace_state_ptr, + bool, + cql3::computed_function_values, + cql3::dialect> future process(uint16_t stream, request_reader in, service::client_state& client_state, service_permit permit, tracing::trace_state_ptr trace_state, Process process_fn); - template - future - process_on_shard(::shared_ptr bounce_msg, uint16_t stream, fragmented_temporary_buffer::istream is, service::client_state& cs, - service_permit permit, tracing::trace_state_ptr trace_state, cql3::dialect dialect, Process process_fn); + + template + requires std::is_invocable_r_v, + Process, + service::client_state&, + distributed&, + request_reader, + uint16_t, + cql_protocol_version_type, + service_permit, + tracing::trace_state_ptr, + bool, + cql3::computed_function_values, + cql3::dialect> + future + 
process_on_shard(shard_id shard, uint16_t stream, fragmented_temporary_buffer::istream is, service::client_state& cs, + tracing::trace_state_ptr trace_state, cql3::dialect dialect, cql3::computed_function_values&& cached_vals, Process process_fn); void write_response(foreign_ptr>&& response, service_permit permit = empty_service_permit(), cql_compression compression = cql_compression::none);