From 4115a6fece30dcef68e17015ff64cdddb5469d36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Chojnowski?= Date: Mon, 17 Mar 2025 18:24:46 +0100 Subject: [PATCH] storage_service: add some dict-related routines storage_service will be the interface between the API layer (or the automatic training loop) and the dict machinery. This commit implements the relevant interface for that. It adds methods that: 1. Take SSTable samples from the cluster, using the new RPC verbs. 2. Train a dict on the sample. (The trainer will be plugged in from `main`). 3. Publishes the trained dictionary. (By adding mutations to Raft group 0). Perhaps this should be moved to a separate "service". But it's not like `storage_service` has a clear purpose anyway. --- service/storage_service.cc | 54 ++++++++++++++++++++++++++++++++++++++ service/storage_service.hh | 11 ++++++++ 2 files changed, 65 insertions(+) diff --git a/service/storage_service.cc b/service/storage_service.cc index 5a01abab39..bacad9ec92 100644 --- a/service/storage_service.cc +++ b/service/storage_service.cc @@ -4790,6 +4790,60 @@ semaphore& storage_service::get_do_sample_sstables_concurrency_limiter() { return _do_sample_sstables_concurrency_limiter; } +/* Sends the estimate_sstable_volume RPC for table `t` to every host id returned by the token metadata and adds the replies together with map_reduce, starting from 0. */ future storage_service::estimate_total_sstable_volume(table_id t) { + co_return co_await seastar::map_reduce( + _db.local().get_token_metadata().get_host_ids(), + [&] (auto h) -> future { + return ser::storage_service_rpc_verbs::send_estimate_sstable_volume(&_messaging.local(), h, t); + }, + uint64_t(0), + std::plus() + ); +} + +/* Copies every buffer of `sample` into a byte vector, then calls the _train_dict callback through container().invoke_on(0, ...) with those copies. Reports an internal error if no callback has been set. */ future> storage_service::train_dict(utils::chunked_vector> sample) { + std::vector> tmp; + tmp.reserve(sample.size()); + for (const auto& s : sample) { + auto v = std::as_bytes(std::span(s)); + tmp.push_back(std::vector(v.begin(), v.end())); + } + /* NOTE(review): this lambda is not `mutable`, so the captured `tmp` is const inside its body and the std::move(tmp) below would select a copy rather than a move - confirm whether that copy is intended. */ co_return co_await container().invoke_on(0, [tmp = std::move(tmp)] (auto& local) { + if (!local._train_dict) { + on_internal_error(slogger, "retrain_dict: _train_dict not plugged"); + } + return 
local._train_dict(std::move(tmp)); + }); +} + +/* Through container().invoke_on(0, ...), builds a group0 batch holding one mutation that inserts `dict` under the name "sstables/<t_id>" and commits it. The whole attempt is repeated for as long as it fails with group0_concurrent_modification. */ future<> storage_service::publish_new_sstable_dict(table_id t_id, std::span dict, service::raft_group0_client& group0_client) { + co_await container().invoke_on(0, coroutine::lambda([t_id, dict, &group0_client] (storage_service& local_ss) -> future<> { + while (true) { + try { + auto name = fmt::format("sstables/{}", t_id); + slogger.debug("publish_new_sstable_dict: trying to publish the dict as {}", name); + auto batch = service::group0_batch(co_await group0_client.start_operation(local_ss.get_abort_source())); + auto write_ts = batch.write_timestamp(); + auto new_dict_ts = db_clock::now(); + /* Build a `bytes` value from the span's data pointer and size. */ auto data = bytes(reinterpret_cast(dict.data()), dict.size()); + auto this_host_id = local_ss._db.local().get_token_metadata().get_topology().get_config().this_host_id; + mutation publish_new_dict = co_await local_ss._sys_ks.local().get_insert_dict_mutation(name, std::move(data), this_host_id, new_dict_ts, write_ts); + batch.add_mutation(std::move(publish_new_dict), "publish new SSTable compression dictionary"); + slogger.debug("publish_new_sstable_dict: committing"); + co_await std::move(batch).commit(group0_client, local_ss.get_abort_source(), {}); + slogger.debug("publish_new_sstable_dict: finished"); + break; + } catch (const service::group0_concurrent_modification&) { + slogger.debug("group0_concurrent_modification in publish_new_sstable_dict, retrying"); + } + } + })); +} + +/* Stores `cb` in _train_dict, the callback that train_dict() invokes. */ void storage_service::set_train_dict_callback(decltype(_train_dict) cb) { + _train_dict = std::move(cb); +} + future>> storage_service::do_sample_sstables(table_id t, uint64_t chunk_size, uint64_t n_chunks) { uint64_t max_chunks_per_round = 16 * 1024 * 1024 / chunk_size; uint64_t chunks_done = 0; diff --git a/service/storage_service.hh b/service/storage_service.hh index 06300cc73f..3a2b249839 100644 --- a/service/storage_service.hh +++ b/service/storage_service.hh @@ -323,6 +323,10 @@ public: return *_shared_token_metadata.get(); } + /* Returns a reference to the _abort_source member. */ abort_source& 
get_abort_source() noexcept { + return _abort_source; + } + private: inet_address get_broadcast_address() const noexcept { return get_token_metadata_ptr()->get_topology().my_address(); @@ -1011,6 +1015,13 @@ private: abort_source _group0_as; std::function(std::string_view)> _compression_dictionary_updated_callback; + using byte_vector = std::vector; + /* Callback invoked by train_dict(); assigned through set_train_dict_callback(). */ std::function(std::vector)> _train_dict; +public: + future estimate_total_sstable_volume(table_id); + future> train_dict(utils::chunked_vector> sample); + future<> publish_new_sstable_dict(table_id, std::span, service::raft_group0_client&); + void set_train_dict_callback(decltype(_train_dict)); utils::disk_space_monitor* _disk_space_monitor; // != nullptr only on shard0.