From 150dce5de0413475fb216c706ecb864ef804bdde Mon Sep 17 00:00:00 2001 From: Benny Halevy Date: Wed, 14 Feb 2024 02:24:35 +0200 Subject: [PATCH 1/8] cql_server: move process_fn_return_type to class definition So it can be used for a template concept in the next patch. Signed-off-by: Benny Halevy --- transport/server.cc | 30 +++++++++++++----------------- transport/server.hh | 2 ++ 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/transport/server.cc b/transport/server.cc index 24bef72be2..d5989b4efa 100644 --- a/transport/server.cc +++ b/transport/server.cc @@ -986,10 +986,6 @@ cql_server::connection::process_on_shard(::shared_ptr>; - static inline cql_server::result_with_foreign_response_ptr convert_error_message_to_coordinator_result(messages::result_message* msg) { return std::move(*dynamic_cast(msg)).get_exception(); } @@ -1004,7 +1000,7 @@ cql_server::connection::process(uint16_t stream, request_reader in, service::cli return process_fn(client_state, _server._query_processor, in, stream, _version, permit, trace_state, true, {}, dialect) .then([stream, &client_state, this, is, permit, process_fn, trace_state, dialect] - (process_fn_return_type msg) mutable { + (cql_server::process_fn_return_type msg) mutable { auto* bounce_msg = std::get_if>(&msg); if (bounce_msg) { return process_on_shard(*bounce_msg, stream, is, client_state, std::move(permit), trace_state, dialect, process_fn); @@ -1014,7 +1010,7 @@ cql_server::connection::process(uint16_t stream, request_reader in, service::cli }); } -static future +static future process_query_internal(service::client_state& client_state, distributed& qp, request_reader in, uint16_t stream, cql_protocol_version_type version, service_permit permit, tracing::trace_state_ptr trace_state, bool init_trace, cql3::computed_function_values cached_pk_fn_calls, @@ -1041,12 +1037,12 @@ process_query_internal(service::client_state& client_state, distributedmove_to_shard()) { - return process_fn_return_type(dynamic_pointer_cast(msg)); + return cql_server::process_fn_return_type(dynamic_pointer_cast(msg)); } else if (msg->is_exception()) { - return process_fn_return_type(convert_error_message_to_coordinator_result(msg.get())); + return cql_server::process_fn_return_type(convert_error_message_to_coordinator_result(msg.get())); } else { tracing::trace(q_state->query_state.get_trace_state(), "Done processing - preparing a result"); - return process_fn_return_type(make_foreign(make_result(stream, *msg, q_state->query_state.get_trace_state(), version, skip_metadata))); + return cql_server::process_fn_return_type(make_foreign(make_result(stream, *msg, q_state->query_state.get_trace_state(), version, skip_metadata))); } }); } @@ -1078,7 +1074,7 @@ future> cql_server::connection::process_pr }); } -static future +static future process_execute_internal(service::client_state& client_state, distributed& qp, request_reader in, uint16_t stream, cql_protocol_version_type version, service_permit permit, tracing::trace_state_ptr trace_state, bool init_trace, cql3::computed_function_values cached_pk_fn_calls, @@ -1139,12 +1135,12 @@ process_execute_internal(service::client_state& client_state, distributedmove_to_shard()) { - return process_fn_return_type(dynamic_pointer_cast(msg)); + return cql_server::process_fn_return_type(dynamic_pointer_cast(msg)); } else if (msg->is_exception()) { - return process_fn_return_type(convert_error_message_to_coordinator_result(msg.get())); + return cql_server::process_fn_return_type(convert_error_message_to_coordinator_result(msg.get())); } else { tracing::trace(q_state->query_state.get_trace_state(), "Done processing - preparing a result"); - return process_fn_return_type(make_foreign(make_result(stream, *msg, q_state->query_state.get_trace_state(), version, skip_metadata))); + return cql_server::process_fn_return_type(make_foreign(make_result(stream, *msg, q_state->query_state.get_trace_state(), version, skip_metadata))); } }); } @@ -1154,7 +1150,7 @@ future cql_server::connection::pro return process(stream, in, client_state, std::move(permit), std::move(trace_state), process_execute_internal); } -static future +static future process_batch_internal(service::client_state& client_state, distributed& qp, request_reader in, uint16_t stream, cql_protocol_version_type version, service_permit permit, tracing::trace_state_ptr trace_state, bool init_trace, cql3::computed_function_values cached_pk_fn_calls, cql3::dialect dialect) { @@ -1260,12 +1256,12 @@ process_batch_internal(service::client_state& client_state, distributedmove_to_shard()) { - return process_fn_return_type(dynamic_pointer_cast(msg)); + return cql_server::process_fn_return_type(dynamic_pointer_cast(msg)); } else if (msg->is_exception()) { - return process_fn_return_type(convert_error_message_to_coordinator_result(msg.get())); + return cql_server::process_fn_return_type(convert_error_message_to_coordinator_result(msg.get())); } else { tracing::trace(q_state->query_state.get_trace_state(), "Done processing - preparing a result"); - return process_fn_return_type(make_foreign(make_result(stream, *msg, trace_state, version))); + return cql_server::process_fn_return_type(make_foreign(make_result(stream, *msg, trace_state, version))); } }); } diff --git a/transport/server.hh b/transport/server.hh index cd4acaf81f..60c0afac47 100644 --- a/transport/server.hh +++ b/transport/server.hh @@ -186,6 +186,8 @@ public: public: using response = cql_transport::response; using result_with_foreign_response_ptr = exceptions::coordinator_result>>; + using process_fn_return_type = std::variant>; + service::endpoint_lifecycle_subscriber* get_lifecycle_listener() const noexcept; service::migration_listener* get_migration_listener() const noexcept; qos::qos_configuration_change_subscriber* get_qos_configuration_listener() const noexcept; From 0df6f5537943891357707b90c29f0ae43f3ec2ca Mon Sep 17 00:00:00 2001 From: Benny Halevy Date: Wed, 14 Feb 2024 02:19:42 +0200 Subject: [PATCH 2/8] cql_server: connection: process: add template concept for process_fn Quoting Avi Kivity: > Out of scope: we should consider detemplating this. As a follow-up we should consider that and pass a function object as process_fn, just make sure there are no drawbacks. Signed-off-by: Benny Halevy --- transport/server.cc | 28 ++++++++++++++++++++++++++-- transport/server.hh | 31 +++++++++++++++++++++++++++++-- 2 files changed, 55 insertions(+), 4 deletions(-) diff --git a/transport/server.cc b/transport/server.cc index d5989b4efa..c552879cd0 100644 --- a/transport/server.cc +++ b/transport/server.cc @@ -962,7 +962,19 @@ std::unique_ptr make_result(int16_t stream, messages::result_message& msg, const tracing::trace_state_ptr& tr_state, cql_protocol_version_type version, bool skip_metadata = false); -template +template + requires std::is_invocable_r_v, + Process, + service::client_state&, + distributed&, + request_reader, + uint16_t, + cql_protocol_version_type, + service_permit, + tracing::trace_state_ptr, + bool, + cql3::computed_function_values, + cql3::dialect> future cql_server::connection::process_on_shard(::shared_ptr bounce_msg, uint16_t stream, fragmented_temporary_buffer::istream is, service::client_state& cs, service_permit permit, tracing::trace_state_ptr trace_state, cql3::dialect dialect, Process process_fn) { @@ -990,7 +1002,19 @@ static inline cql_server::result_with_foreign_response_ptr convert_error_message return std::move(*dynamic_cast(msg)).get_exception(); } -template +template + requires std::is_invocable_r_v, + Process, + service::client_state&, + distributed&, + request_reader, + uint16_t, + cql_protocol_version_type, + service_permit, + tracing::trace_state_ptr, + bool, + cql3::computed_function_values, + cql3::dialect> future cql_server::connection::process(uint16_t stream, request_reader in, service::client_state& client_state, service_permit permit, tracing::trace_state_ptr trace_state, Process process_fn) { diff --git a/transport/server.hh b/transport/server.hh index 60c0afac47..e8ab688e48 100644 --- a/transport/server.hh +++ b/transport/server.hh @@ -18,6 +18,7 @@ #include "timeout_config.hh" #include #include +#include #include #include #include @@ -282,11 +283,37 @@ private: cql3::dialect get_dialect() const; // Helper functions to encapsulate bounce_to_shard processing for query, execute and batch verbs - template + template + requires std::is_invocable_r_v, + Process, + service::client_state&, + distributed&, + request_reader, + uint16_t, + cql_protocol_version_type, + service_permit, + tracing::trace_state_ptr, + bool, + cql3::computed_function_values, + cql3::dialect> future process(uint16_t stream, request_reader in, service::client_state& client_state, service_permit permit, tracing::trace_state_ptr trace_state, Process process_fn); - template + + + template + requires std::is_invocable_r_v, + Process, + service::client_state&, + distributed&, + request_reader, + uint16_t, + cql_protocol_version_type, + service_permit, + tracing::trace_state_ptr, + bool, + cql3::computed_function_values, + cql3::dialect> future process_on_shard(::shared_ptr bounce_msg, uint16_t stream, fragmented_temporary_buffer::istream is, service::client_state& cs, service_permit permit, tracing::trace_state_ptr trace_state, cql3::dialect dialect, Process process_fn); From eb7fbdbed2860b6515d7a1bdf8f8e7b496e8255b Mon Sep 17 00:00:00 2001 From: Benny Halevy Date: Wed, 14 Feb 2024 12:12:28 +0200 Subject: [PATCH 3/8] transport: server: pass bounce_to_shard as foreign shared ptr So it can safely passed between shards, as will be needed in the following patch that handles a (re)bounce_to_shard result from process_fn that's called by `process_on_shard` on the `move_to_shard`. With that in mind, pass the `bounce_to_shard` payload to `process_on_shard` rather than the foreign shared ptr since the latter grabs what it needs from it on entry and the shared_ptr can be released on the calling shard. Signed-off-by: Benny Halevy --- transport/server.cc | 28 +++++++++++++++------------- transport/server.hh | 8 ++++---- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/transport/server.cc b/transport/server.cc index c552879cd0..0e36def4f4 100644 --- a/transport/server.cc +++ b/transport/server.cc @@ -976,21 +976,21 @@ template cql3::computed_function_values, cql3::dialect> future -cql_server::connection::process_on_shard(::shared_ptr bounce_msg, uint16_t stream, fragmented_temporary_buffer::istream is, - service::client_state& cs, service_permit permit, tracing::trace_state_ptr trace_state, cql3::dialect dialect, Process process_fn) { - return _server.container().invoke_on(*bounce_msg->move_to_shard(), _server._config.bounce_request_smp_service_group, - [this, is = std::move(is), cs = cs.move_to_other_shard(), stream, permit = std::move(permit), process_fn, +cql_server::connection::process_on_shard(shard_id shard, uint16_t stream, fragmented_temporary_buffer::istream is, + service::client_state& cs, service_permit permit, tracing::trace_state_ptr trace_state, cql3::dialect dialect, cql3::computed_function_values&& cached_vals, Process process_fn) { + return _server.container().invoke_on(shard, _server._config.bounce_request_smp_service_group, + [this, is, cs = cs.move_to_other_shard(), stream, permit = std::move(permit), process_fn, gt = tracing::global_trace_state_ptr(std::move(trace_state)), - cached_vals = std::move(bounce_msg->take_cached_pk_function_calls()), dialect] (cql_server& server) { + cached_vals, dialect] (cql_server& server) { service::client_state client_state = cs.get(); - return do_with(bytes_ostream(), std::move(client_state), std::move(cached_vals), - [this, &server, is = std::move(is), stream, process_fn, + return do_with(bytes_ostream(), std::move(client_state), cached_vals, + [this, &server, is = is, stream, process_fn, trace_state = tracing::trace_state_ptr(gt), dialect] (bytes_ostream& linearization_buffer, service::client_state& client_state, cql3::computed_function_values& cached_vals) mutable { request_reader in(is, linearization_buffer); return process_fn(client_state, server._query_processor, in, stream, _version, - /* FIXME */empty_service_permit(), std::move(trace_state), false, std::move(cached_vals), dialect).then([] (auto msg) { + /* FIXME */empty_service_permit(), std::move(trace_state), false, cached_vals, dialect).then([] (auto msg) { // result here has to be foreign ptr return std::get(std::move(msg)); }); @@ -1025,9 +1025,11 @@ cql_server::connection::process(uint16_t stream, request_reader in, service::cli _version, permit, trace_state, true, {}, dialect) .then([stream, &client_state, this, is, permit, process_fn, trace_state, dialect] (cql_server::process_fn_return_type msg) mutable { - auto* bounce_msg = std::get_if>(&msg); + auto* bounce_msg = std::get_if(&msg); if (bounce_msg) { - return process_on_shard(*bounce_msg, stream, is, client_state, std::move(permit), trace_state, dialect, process_fn); + auto shard = (*bounce_msg)->move_to_shard().value(); + auto&& cached_vals = (*bounce_msg)->take_cached_pk_function_calls(); + return process_on_shard(shard, stream, is, client_state, std::move(permit), trace_state, dialect, std::move(cached_vals), process_fn); } auto ptr = std::get(std::move(msg)); return make_ready_future(std::move(ptr)); @@ -1061,7 +1063,7 @@ process_query_internal(service::client_state& client_state, distributedmove_to_shard()) { - return cql_server::process_fn_return_type(dynamic_pointer_cast(msg)); + return cql_server::process_fn_return_type(make_foreign(dynamic_pointer_cast(msg))); } else if (msg->is_exception()) { return cql_server::process_fn_return_type(convert_error_message_to_coordinator_result(msg.get())); } else { @@ -1159,7 +1161,7 @@ process_execute_internal(service::client_state& client_state, distributedmove_to_shard()) { - return cql_server::process_fn_return_type(dynamic_pointer_cast(msg)); + return cql_server::process_fn_return_type(make_foreign(dynamic_pointer_cast(msg))); } else if (msg->is_exception()) { return cql_server::process_fn_return_type(convert_error_message_to_coordinator_result(msg.get())); } else { @@ -1280,7 +1282,7 @@ process_batch_internal(service::client_state& client_state, distributedmove_to_shard()) { - return cql_server::process_fn_return_type(dynamic_pointer_cast(msg)); + return cql_server::process_fn_return_type(make_foreign(dynamic_pointer_cast(msg))); } else if (msg->is_exception()) { return cql_server::process_fn_return_type(convert_error_message_to_coordinator_result(msg.get())); } else { diff --git a/transport/server.hh b/transport/server.hh index e8ab688e48..b69ce9b300 100644 --- a/transport/server.hh +++ b/transport/server.hh @@ -187,7 +187,8 @@ public: public: using response = cql_transport::response; using result_with_foreign_response_ptr = exceptions::coordinator_result>>; - using process_fn_return_type = std::variant>; + using result_with_bounce_to_shard = foreign_ptr>; + using process_fn_return_type = std::variant; service::endpoint_lifecycle_subscriber* get_lifecycle_listener() const noexcept; service::migration_listener* get_migration_listener() const noexcept; @@ -300,7 +301,6 @@ private: process(uint16_t stream, request_reader in, service::client_state& client_state, service_permit permit, tracing::trace_state_ptr trace_state, Process process_fn); - template requires std::is_invocable_r_v, Process, @@ -315,8 +315,8 @@ private: cql3::computed_function_values, cql3::dialect> future - process_on_shard(::shared_ptr bounce_msg, uint16_t stream, fragmented_temporary_buffer::istream is, service::client_state& cs, - service_permit permit, tracing::trace_state_ptr trace_state, cql3::dialect dialect, Process process_fn); + process_on_shard(shard_id shard, uint16_t stream, fragmented_temporary_buffer::istream is, service::client_state& cs, + service_permit permit, tracing::trace_state_ptr trace_state, cql3::dialect dialect, cql3::computed_function_values&& cached_vals, Process process_fn); void write_response(foreign_ptr>&& response, service_permit permit = empty_service_permit(), cql_compression compression = cql_compression::none); From 71052dca6a3876247aba854268b1b95e3a515759 Mon Sep 17 00:00:00 2001 From: Benny Halevy Date: Wed, 14 Feb 2024 14:19:39 +0200 Subject: [PATCH 4/8] cql_server: connection: process_on_shard: drop permit parameter It is currently unused in `process_on_shard`, which generates an empty service_permit. The next patch may call process_on_shard in a loop, so it can't simply move the permit to the callee and better hold on to it until processing completes. `cql_server::connection::process` was turned into a coroutine in this patch to hold on to the permit parameter in a simple way. This is a preliminary step to changing `if (bounce_msg)` to `while (bounce_msg)` that will allow rebouncing the message in case it moved yet again when yielding in `process_on_shard`. Signed-off-by: Benny Halevy --- transport/server.cc | 17 +++++++---------- transport/server.hh | 2 +- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/transport/server.cc b/transport/server.cc index 0e36def4f4..6a85964652 100644 --- a/transport/server.cc +++ b/transport/server.cc @@ -977,9 +977,9 @@ template cql3::dialect> future cql_server::connection::process_on_shard(shard_id shard, uint16_t stream, fragmented_temporary_buffer::istream is, - service::client_state& cs, service_permit permit, tracing::trace_state_ptr trace_state, cql3::dialect dialect, cql3::computed_function_values&& cached_vals, Process process_fn) { + service::client_state& cs, tracing::trace_state_ptr trace_state, cql3::dialect dialect, cql3::computed_function_values&& cached_vals, Process process_fn) { return _server.container().invoke_on(shard, _server._config.bounce_request_smp_service_group, - [this, is, cs = cs.move_to_other_shard(), stream, permit = std::move(permit), process_fn, + [this, is, cs = cs.move_to_other_shard(), stream, process_fn, gt = tracing::global_trace_state_ptr(std::move(trace_state)), cached_vals, dialect] (cql_server& server) { service::client_state client_state = cs.get(); @@ -1021,19 +1021,16 @@ cql_server::connection::process(uint16_t stream, request_reader in, service::cli fragmented_temporary_buffer::istream is = in.get_stream(); auto dialect = get_dialect(); - return process_fn(client_state, _server._query_processor, in, stream, - _version, permit, trace_state, true, {}, dialect) - .then([stream, &client_state, this, is, permit, process_fn, trace_state, dialect] - (cql_server::process_fn_return_type msg) mutable { + auto msg = co_await process_fn(client_state, _server._query_processor, in, stream, + _version, permit, trace_state, true, {}, dialect); + // FIXME: indentation auto* bounce_msg = std::get_if(&msg); if (bounce_msg) { auto shard = (*bounce_msg)->move_to_shard().value(); auto&& cached_vals = (*bounce_msg)->take_cached_pk_function_calls(); - return process_on_shard(shard, stream, is, client_state, std::move(permit), trace_state, dialect, std::move(cached_vals), process_fn); + co_return co_await process_on_shard(shard, stream, is, client_state, trace_state, dialect, std::move(cached_vals), process_fn); } - auto ptr = std::get(std::move(msg)); - return make_ready_future(std::move(ptr)); - }); + co_return std::get(std::move(msg)); } static future diff --git a/transport/server.hh b/transport/server.hh index b69ce9b300..8cd4488a93 100644 --- a/transport/server.hh +++ b/transport/server.hh @@ -316,7 +316,7 @@ private: cql3::dialect> future process_on_shard(shard_id shard, uint16_t stream, fragmented_temporary_buffer::istream is, service::client_state& cs, - service_permit permit, tracing::trace_state_ptr trace_state, cql3::dialect dialect, cql3::computed_function_values&& cached_vals, Process process_fn); + tracing::trace_state_ptr trace_state, cql3::dialect dialect, cql3::computed_function_values&& cached_vals, Process process_fn); void write_response(foreign_ptr>&& response, service_permit permit = empty_service_permit(), cql_compression compression = cql_compression::none); From 0b93409b4446d3dfb8f237ccc6753da2c2cceaa8 Mon Sep 17 00:00:00 2001 From: Benny Halevy Date: Wed, 14 Feb 2024 14:28:40 +0200 Subject: [PATCH 5/8] cql_server: connection: process: fixup indentation Signed-off-by: Benny Halevy --- transport/server.cc | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/transport/server.cc b/transport/server.cc index 6a85964652..bfc857dd97 100644 --- a/transport/server.cc +++ b/transport/server.cc @@ -1023,14 +1023,13 @@ cql_server::connection::process(uint16_t stream, request_reader in, service::cli auto dialect = get_dialect(); auto msg = co_await process_fn(client_state, _server._query_processor, in, stream, _version, permit, trace_state, true, {}, dialect); - // FIXME: indentation - auto* bounce_msg = std::get_if(&msg); - if (bounce_msg) { - auto shard = (*bounce_msg)->move_to_shard().value(); - auto&& cached_vals = (*bounce_msg)->take_cached_pk_function_calls(); - co_return co_await process_on_shard(shard, stream, is, client_state, trace_state, dialect, std::move(cached_vals), process_fn); - } - co_return std::get(std::move(msg)); + auto* bounce_msg = std::get_if(&msg); + if (bounce_msg) { + auto shard = (*bounce_msg)->move_to_shard().value(); + auto&& cached_vals = (*bounce_msg)->take_cached_pk_function_calls(); + co_return co_await process_on_shard(shard, stream, is, client_state, trace_state, dialect, std::move(cached_vals), process_fn); + } + co_return std::get(std::move(msg)); } static future From f674f522aa36a53c314d2df1499f6e7d7818cd6c Mon Sep 17 00:00:00 2001 From: Sergey Zolotukhin Date: Tue, 17 Sep 2024 14:54:30 +0200 Subject: [PATCH 6/8] cql_server::connection: process: co-routinize connection::process_on_shard `cql_server::connection::process_on_shard` is made a co-routine to make sure captured objects' lifetime is managed by the source shard, avoiding error prone inter-shard objects transfers. --- transport/server.cc | 35 ++++++++++++++--------------------- transport/server.hh | 2 +- 2 files changed, 15 insertions(+), 22 deletions(-) diff --git a/transport/server.cc b/transport/server.cc index bfc857dd97..ed6e9d5573 100644 --- a/transport/server.cc +++ b/transport/server.cc @@ -975,26 +975,19 @@ template bool, cql3::computed_function_values, cql3::dialect> -future -cql_server::connection::process_on_shard(shard_id shard, uint16_t stream, fragmented_temporary_buffer::istream is, - service::client_state& cs, tracing::trace_state_ptr trace_state, cql3::dialect dialect, cql3::computed_function_values&& cached_vals, Process process_fn) { - return _server.container().invoke_on(shard, _server._config.bounce_request_smp_service_group, - [this, is, cs = cs.move_to_other_shard(), stream, process_fn, - gt = tracing::global_trace_state_ptr(std::move(trace_state)), - cached_vals, dialect] (cql_server& server) { - service::client_state client_state = cs.get(); - return do_with(bytes_ostream(), std::move(client_state), cached_vals, - [this, &server, is = is, stream, process_fn, - trace_state = tracing::trace_state_ptr(gt), dialect] (bytes_ostream& linearization_buffer, - service::client_state& client_state, - cql3::computed_function_values& cached_vals) mutable { - request_reader in(is, linearization_buffer); - return process_fn(client_state, server._query_processor, in, stream, _version, - /* FIXME */empty_service_permit(), std::move(trace_state), false, cached_vals, dialect).then([] (auto msg) { - // result here has to be foreign ptr - return std::get(std::move(msg)); - }); - }); +future +cql_server::connection::process_on_shard(shard_id shard, uint16_t stream, fragmented_temporary_buffer::istream is, service::client_state& cs, + tracing::trace_state_ptr trace_state, cql3::dialect dialect, cql3::computed_function_values&& cached_vals, Process process_fn) { + auto sg = _server._config.bounce_request_smp_service_group; + auto gcs = cs.move_to_other_shard(); + auto gt = tracing::global_trace_state_ptr(std::move(trace_state)); + co_return co_await _server.container().invoke_on(shard, sg, [&, stream, dialect] (cql_server& server) -> future { + bytes_ostream linearization_buffer; + request_reader in(is, linearization_buffer); + auto client_state = gcs.get(); + auto trace_state = gt.get(); + co_return co_await process_fn(client_state, server._query_processor, in, stream, _version, + /* FIXME */empty_service_permit(), std::move(trace_state), false, cached_vals, dialect); }); } @@ -1027,7 +1020,7 @@ cql_server::connection::process(uint16_t stream, request_reader in, service::cli if (bounce_msg) { auto shard = (*bounce_msg)->move_to_shard().value(); auto&& cached_vals = (*bounce_msg)->take_cached_pk_function_calls(); - co_return co_await process_on_shard(shard, stream, is, client_state, trace_state, dialect, std::move(cached_vals), process_fn); + msg = co_await process_on_shard(shard, stream, is, client_state, trace_state, dialect, std::move(cached_vals), process_fn); } co_return std::get(std::move(msg)); } diff --git a/transport/server.hh b/transport/server.hh index 8cd4488a93..52cf61b050 100644 --- a/transport/server.hh +++ b/transport/server.hh @@ -314,7 +314,7 @@ private: bool, cql3::computed_function_values, cql3::dialect> - future + future process_on_shard(shard_id shard, uint16_t stream, fragmented_temporary_buffer::istream is, service::client_state& cs, tracing::trace_state_ptr trace_state, cql3::dialect dialect, cql3::computed_function_values&& cached_vals, Process process_fn); From 65430b9e1b6d9d3f0acc35fb034977570e5dc8f9 Mon Sep 17 00:00:00 2001 From: Benny Halevy Date: Tue, 13 Feb 2024 12:21:16 +0200 Subject: [PATCH 7/8] cql_server::connection: process: rebounce msg if needed Rebounce the msg to another shard if needed, e.g. in the case of tablet migration. An example for that, as given by Tomasz Grabiec: > Bouncing happens when executing LWT statement in > modification_statement::execute_with_condition by returning a > special result message kind. The code assumes that after > jumping to the shard from the bounce request, the result > message is the regular one and not yet another bounce. > There is no problem with vnodes, because shards don't change. > With tablets, they can change at run time on migration. Fixes scylladb/scylladb#15465 Signed-off-by: Benny Halevy --- transport/server.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/transport/server.cc b/transport/server.cc index ed6e9d5573..af1f3bbe0b 100644 --- a/transport/server.cc +++ b/transport/server.cc @@ -1016,8 +1016,7 @@ cql_server::connection::process(uint16_t stream, request_reader in, service::cli auto dialect = get_dialect(); auto msg = co_await process_fn(client_state, _server._query_processor, in, stream, _version, permit, trace_state, true, {}, dialect); - auto* bounce_msg = std::get_if(&msg); - if (bounce_msg) { + while (auto* bounce_msg = std::get_if(&msg)) { auto shard = (*bounce_msg)->move_to_shard().value(); auto&& cached_vals = (*bounce_msg)->take_cached_pk_function_calls(); msg = co_await process_on_shard(shard, stream, is, client_state, trace_state, dialect, std::move(cached_vals), process_fn); From 68740f57c2bd13a493106d694ec2e119dc0d56b0 Mon Sep 17 00:00:00 2001 From: Sergey Zolotukhin Date: Thu, 5 Sep 2024 12:11:11 +0200 Subject: [PATCH 8/8] cql_server: Add a test for multiple query msg rebounces. The test emulates several LWT(Lightweight Transaction) query rebounces. Currently, the code that processes queries does not expect that a query may be rebounced more than once. It was impossible with the VNodes, but with intruduction of the Tablets, data can be moved between shards by the balancer thus a query can be rebounced to different shards multiple times. --- cql3/statements/modification_statement.cc | 61 +++++++++++++++++++ test/topology_custom/test_query_rebounce.py | 65 +++++++++++++++++++++ 2 files changed, 126 insertions(+) create mode 100644 test/topology_custom/test_query_rebounce.py diff --git a/cql3/statements/modification_statement.cc b/cql3/statements/modification_statement.cc index 699bac8503..dc756936b6 100644 --- a/cql3/statements/modification_statement.cc +++ b/cql3/statements/modification_statement.cc @@ -316,6 +316,62 @@ modification_statement::execute_without_condition(query_processor& qp, service:: }); } +namespace { + +future<::shared_ptr> +process_forced_rebounce(unsigned shard, query_processor& qp, const query_options& options) { + static int64_t counter = {0}; + static logging::logger logger("modification_statement"); + if (counter <= 0) { + const auto counter_opt = utils::get_local_injector().inject_parameter("forced_bounce_to_shard_counter"); + decltype(counter) counter_value = 0; + if (!counter_opt) { + logger.warn("forced_bounce_to_shard_counter is not set. Using default value 1."); + } else { + try { + counter_value = boost::lexical_cast(*counter_opt); + } catch (const boost::bad_lexical_cast& e) { + logger.warn("Incorrect forced_bounce_to_shard_counter value: [{}]. Using default value 1.", *counter_opt); + } + } + if (counter_value <= 0) { + counter_value = 1; + } + counter = counter_value; + } + + const auto prev_counter_value = counter; + if (prev_counter_value <= 1) { + logger.info("Disabling forced_bounce_to_shard_counter."); + co_await utils::error_injection_type::disable_on_all("forced_bounce_to_shard_counter"); + counter = 0; + } else { + --counter; + } + + // While counter > 1 select a different shard to re-bounce to. + // On the last iteration, re-bounce to the correct shard. + if (counter != 0) { + const auto shard_num = smp::count; + assert(shard_num > 0); + const auto local_shard = this_shard_id(); + auto target_shard = local_shard + 1; + if (target_shard == shard) { + ++target_shard; + } + if (target_shard > shard_num - 1) { + target_shard = 0; + } + shard = target_shard; + } + + logger.info("Applying forced_bounce_to_shard_counter, re-bouncing to shard {}.", shard); + co_return co_await make_ready_future>( + qp.bounce_to_shard(shard, std::move(const_cast(options).take_cached_pk_function_calls()))); +} + +} // namespace + future<::shared_ptr> modification_statement::execute_with_condition(query_processor& qp, service::query_state& qs, const query_options& options) const { @@ -349,6 +405,11 @@ modification_statement::execute_with_condition(query_processor& qp, service::que auto token = request->key()[0].start()->value().as_decorated_key().token(); auto shard = service::storage_proxy::cas_shard(*s, token); + + if (utils::get_local_injector().is_enabled("forced_bounce_to_shard_counter")) { + return process_forced_rebounce(shard, qp, options); + } + if (shard != this_shard_id()) { return make_ready_future>( qp.bounce_to_shard(shard, std::move(const_cast(options).take_cached_pk_function_calls())) diff --git a/test/topology_custom/test_query_rebounce.py b/test/topology_custom/test_query_rebounce.py new file mode 100644 index 0000000000..fc4d6b4d30 --- /dev/null +++ b/test/topology_custom/test_query_rebounce.py @@ -0,0 +1,65 @@ +# +# Copyright (C) 2024-present ScyllaDB +# +# SPDX-License-Identifier: AGPL-3.0-or-later +# + +import asyncio +import pytest +import logging +import time + +from test.pylib.internal_types import IPAddress +from test.pylib.manager_client import ManagerClient +from test.pylib.rest_client import inject_error + + +logger = logging.getLogger(__name__) + + +@pytest.mark.asyncio +async def test_query_rebounce(manager: ManagerClient): + """ + Issue https://github.com/scylladb/scylladb/issues/15465. + + Test emulating several LWT(Lightweight Transaction) query rebounces. Currently, the code + that processes queries does not expect that a query may be rebounced more than once. + It was impossible with the VNodes, but with intruduction of the Tablets, data can be moved + between shards by the balancer thus a query can be rebounced to different shards multiple times. + + 1) Create a keyspace and a table. + 2) Insert some data. + 3) Inject an error to force a rebounce 2 times. + 4) Update the data with a LWT query. The update will fail on this step with the current implementation. + 5) Check the result of update. + + """ + + cmdline = [ + '--logger-log-level', 'raft=trace', + ] + + servers = await manager.servers_add(1, cmdline=cmdline) + + servers = await manager.running_servers() + cql = manager.get_cql() + + await cql.run_async("create keyspace ks with replication = {'class': 'SimpleStrategy', 'replication_factor': 1}" + "and tablets = {'enabled': false};") + + await cql.run_async("create table ks.lwt (a int, b int, primary key(a));") + + await cql.run_async("insert into ks.lwt (a,b ) values (1, 10);") + await cql.run_async("insert into ks.lwt (a,b ) values (2, 20);") + + errs = [manager.api.enable_injection(s.ip_addr, "forced_bounce_to_shard_counter", one_shot=False, + parameters={'value': '2'}) + for s in servers] + + await asyncio.gather(*errs) + + await cql.run_async("update ks.lwt set b = 11 where a = 1 if b = 10;") + + rows = await cql.run_async("select b from ks.lwt where a = 1;") + + assert rows[0].b == 11