From 7bf53fc9089ce8183fccc4318285caf26185004b Mon Sep 17 00:00:00 2001 From: Pawel Pery Date: Wed, 2 Jul 2025 19:26:32 +0200 Subject: [PATCH 1/5] vector_store_client: implement initial vector_store_client service This patch is a part of vector_store_client sharded service implementation for a communication with vector-store service. It adds a `services/vector_store_client.{cc|hh}` sharded service and a configuration parameter `vector_store_uri` with a `http://vector-store.dns.name:port` format. If there will be an error during parsing that parameter there will be an exception during construction. For the future unit testing purposes the patch adds `vector_store_client_tester` as a way to inject mockup functionality. This service will be used by the select statements for the Vector search indexes (see VS-46). For this reason I've added vector_store_client service in the query processor. Reference: VS-47 VS-45 --- conf/scylla.yaml | 7 ++ configure.py | 2 + cql3/query_processor.cc | 4 +- cql3/query_processor.hh | 12 ++- db/config.cc | 1 + db/config.hh | 2 + main.cc | 14 +++- service/CMakeLists.txt | 3 +- service/vector_store_client.cc | 111 +++++++++++++++++++++++++ service/vector_store_client.hh | 63 ++++++++++++++ test/boost/CMakeLists.txt | 2 + test/boost/vector_store_client_test.cc | 64 ++++++++++++++ test/lib/cql_test_env.cc | 8 +- 13 files changed, 287 insertions(+), 6 deletions(-) create mode 100644 service/vector_store_client.cc create mode 100644 service/vector_store_client.hh create mode 100644 test/boost/vector_store_client_test.cc diff --git a/conf/scylla.yaml b/conf/scylla.yaml index fd6c9e9bd9..00bbd53f7f 100644 --- a/conf/scylla.yaml +++ b/conf/scylla.yaml @@ -855,3 +855,10 @@ rf_rack_valid_keyspaces: false # Maximum number of items in single BatchWriteItem command. Default is 100. # Note: DynamoDB has a hard-coded limit of 25. # alternator_max_items_in_batch_write: 100 + +# +# Vector Store options +# +# Uri for the vector store using dns name. Only http schema is supported. Port number is mandatory. +# Default is empty, which means that the vector store is not used. +# vector_store_uri: http://vector-store.dns.name:{port} diff --git a/configure.py b/configure.py index 04fabc1159..0a916d47ee 100755 --- a/configure.py +++ b/configure.py @@ -567,6 +567,7 @@ scylla_tests = set([ 'test/boost/symmetric_key_test', 'test/boost/types_test', 'test/boost/utf8_test', + 'test/boost/vector_store_client_test', 'test/boost/vint_serialization_test', 'test/boost/virtual_table_mutation_source_test', 'test/boost/wasm_alloc_test', @@ -1219,6 +1220,7 @@ scylla_core = (['message/messaging_service.cc', 'node_ops/task_manager_module.cc', 'reader_concurrency_semaphore_group.cc', 'utils/disk_space_monitor.cc', + 'service/vector_store_client.cc', ] + [Antlr3Grammar('cql3/Cql.g')] \ + scylla_raft_core ) diff --git a/cql3/query_processor.cc b/cql3/query_processor.cc index 0a5df800e2..4ede4485f4 100644 --- a/cql3/query_processor.cc +++ b/cql3/query_processor.cc @@ -27,6 +27,7 @@ #include "cql3/untyped_result_set.hh" #include "db/config.hh" #include "data_dictionary/data_dictionary.hh" +#include "service/vector_store_client.hh" #include "utils/hashers.hh" #include "utils/error_injection.hh" #include "service/migration_manager.hh" @@ -68,11 +69,12 @@ static service::query_state query_state_for_internal_call() { return {service::client_state::for_internal_calls(), empty_service_permit()}; } -query_processor::query_processor(service::storage_proxy& proxy, data_dictionary::database db, service::migration_notifier& mn, query_processor::memory_config mcfg, cql_config& cql_cfg, utils::loading_cache_config auth_prep_cache_cfg, lang::manager& langm) +query_processor::query_processor(service::storage_proxy& proxy, data_dictionary::database db, service::migration_notifier& mn, service::vector_store_client& vsc, query_processor::memory_config mcfg, cql_config& cql_cfg, utils::loading_cache_config auth_prep_cache_cfg, lang::manager& langm) : _migration_subscriber{std::make_unique(this)} , _proxy(proxy) , _db(db) , _mnotifier(mn) + , _vector_store_client(vsc) , _mcfg(mcfg) , _cql_config(cql_cfg) , _prepared_cache(prep_cache_log, _mcfg.prepared_statment_cache_size) diff --git a/cql3/query_processor.hh b/cql3/query_processor.hh index 13a2242184..b17ef86e98 100644 --- a/cql3/query_processor.hh +++ b/cql3/query_processor.hh @@ -28,6 +28,7 @@ #include "transport/messages/result_message.hh" #include "service/client_state.hh" #include "service/broadcast_tables/experimental/query_result.hh" +#include "service/vector_store_client.hh" #include "utils/assert.hh" #include "utils/observable.hh" #include "service/raft/raft_group0_client.hh" @@ -107,6 +108,7 @@ private: service::storage_proxy& _proxy; data_dictionary::database _db; service::migration_notifier& _mnotifier; + service::vector_store_client& _vector_store_client; memory_config _mcfg; const cql_config& _cql_config; @@ -146,7 +148,7 @@ public: static std::unique_ptr parse_statement(const std::string_view& query, dialect d); static std::vector> parse_statements(std::string_view queries, dialect d); - query_processor(service::storage_proxy& proxy, data_dictionary::database db, service::migration_notifier& mn, memory_config mcfg, cql_config& cql_cfg, utils::loading_cache_config auth_prep_cache_cfg, lang::manager& langm); + query_processor(service::storage_proxy& proxy, data_dictionary::database db, service::migration_notifier& mn, service::vector_store_client& vsc, memory_config mcfg, cql_config& cql_cfg, utils::loading_cache_config auth_prep_cache_cfg, lang::manager& langm); ~query_processor(); @@ -176,6 +178,14 @@ public: lang::manager& lang() { return _lang_manager; } + const service::vector_store_client& vector_store_client() const noexcept { + return _vector_store_client; + } + + service::vector_store_client& vector_store_client() noexcept { + return _vector_store_client; + } + db::auth_version_t auth_version; statements::prepared_statement::checked_weak_ptr get_prepared(const std::optional& user, const prepared_cache_key_type& key) { diff --git a/db/config.cc b/db/config.cc index 72d62c3eca..3dc7e7049b 100644 --- a/db/config.cc +++ b/db/config.cc @@ -1361,6 +1361,7 @@ db::config::config(std::shared_ptr exts) // alternator_max_items_in_batch_write matches DynamoDB behaviour of size limit, but with different value - for DynamoDB it's 25 // (see DynamoDB's documentation for BatchWriteItem command) , alternator_max_items_in_batch_write(this, "alternator_max_items_in_batch_write", value_status::Used, 100, "Maximum amount of items in single BatchItemWrite call.") + , vector_store_uri(this, "vector_store_uri", value_status::Used, "", "The URI of the vector store to use for vector search. If not set, vector search is disabled.") , abort_on_ebadf(this, "abort_on_ebadf", value_status::Used, true, "Abort the server on incorrect file descriptor access. Throws exception when disabled.") , redis_port(this, "redis_port", value_status::Used, 0, "Port on which the REDIS transport listens for clients.") , redis_ssl_port(this, "redis_ssl_port", value_status::Used, 0, "Port on which the REDIS TLS native transport listens for clients.") diff --git a/db/config.hh b/db/config.hh index ff97a2f28c..1cc1c0415d 100644 --- a/db/config.hh +++ b/db/config.hh @@ -488,6 +488,8 @@ public: named_value alternator_describe_endpoints; named_value alternator_max_items_in_batch_write; + named_value vector_store_uri; + named_value abort_on_ebadf; named_value redis_port; diff --git a/main.cc b/main.cc index 19dae2099a..929cc2075d 100644 --- a/main.cc +++ b/main.cc @@ -40,6 +40,7 @@ #include "service/migration_manager.hh" #include "service/tablet_allocator.hh" #include "service/load_meter.hh" +#include "service/vector_store_client.hh" #include "service/view_update_backlog_broker.hh" #include "service/qos/service_level_controller.hh" #include "streaming/stream_session.hh" @@ -740,6 +741,7 @@ sharded token_metadata; sharded mapreduce_service; sharded gossiper; sharded snitch; + sharded vector_store_client; // This worker wasn't designed to be used from multiple threads. // If you are attempting to do that, make sure you know what you are doing. @@ -779,7 +781,8 @@ sharded token_metadata; return seastar::async([&app, cfg, ext, &disk_space_monitor_shard0, &cm, &sstm, &db, &qp, &bm, &proxy, &mapreduce_service, &mm, &mm_notifier, &ctx, &opts, &dirs, &prometheus_server, &cf_cache_hitrate_calculator, &load_meter, &feature_service, &gossiper, &snitch, &token_metadata, &erm_factory, &snapshot_ctl, &messaging, &sst_dir_semaphore, &raft_gr, &service_memory_limiter, - &repair, &sst_loader, &ss, &lifecycle_notifier, &stream_manager, &task_manager, &rpc_dict_training_worker] { + &repair, &sst_loader, &ss, &lifecycle_notifier, &stream_manager, &task_manager, &rpc_dict_training_worker, + &vector_store_client] { try { if (opts.contains("relabel-config-file") && !opts["relabel-config-file"].as().empty()) { // calling update_relabel_config_from_file can cause an exception that would stop startup @@ -1318,6 +1321,13 @@ sharded token_metadata; static sharded cql_config; cql_config.start(std::ref(*cfg)).get(); + checkpoint(stop_signal, "starting a vector store service"); + vector_store_client.start(std::ref(*cfg)).get(); + auto stop_vector_store_client = defer_verbose_shutdown("vector store client", [&vector_store_client] { + vector_store_client.stop().get(); + }); + vector_store_client.invoke_on_all(&service::vector_store_client::start_background_tasks).get(); + checkpoint(stop_signal, "starting query processor"); cql3::query_processor::memory_config qp_mcfg = {memory::stats().total_memory() / 256, memory::stats().total_memory() / 2560}; debug::the_query_processor = &qp; @@ -1329,7 +1339,7 @@ sharded token_metadata; std::chrono::duration_cast(cql3::prepared_statements_cache::entry_expiry)); auth_prep_cache_config.refresh = std::chrono::milliseconds(cfg->permissions_update_interval_in_ms()); - qp.start(std::ref(proxy), std::move(local_data_dict), std::ref(mm_notifier), qp_mcfg, std::ref(cql_config), std::move(auth_prep_cache_config), std::ref(langman)).get(); + qp.start(std::ref(proxy), std::move(local_data_dict), std::ref(mm_notifier), std::ref(vector_store_client), qp_mcfg, std::ref(cql_config), std::move(auth_prep_cache_config), std::ref(langman)).get(); checkpoint(stop_signal, "starting lifecycle notifier"); lifecycle_notifier.start().get(); diff --git a/service/CMakeLists.txt b/service/CMakeLists.txt index 4c672cb3d9..1e1ee39580 100644 --- a/service/CMakeLists.txt +++ b/service/CMakeLists.txt @@ -33,7 +33,8 @@ target_sources(service task_manager_module.cc topology_coordinator.cc topology_mutation.cc - topology_state_machine.cc) + topology_state_machine.cc + vector_store_client.cc) target_include_directories(service PUBLIC ${CMAKE_SOURCE_DIR}) diff --git a/service/vector_store_client.cc b/service/vector_store_client.cc new file mode 100644 index 0000000000..9b1ceb9b09 --- /dev/null +++ b/service/vector_store_client.cc @@ -0,0 +1,111 @@ +/* + * Copyright (C) 2025-present ScyllaDB + */ + +/* + * SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 + */ + +#include "vector_store_client.hh" +#include "db/config.hh" +#include "exceptions/exceptions.hh" +#include +#include + +namespace { + +using configuration_exception = exceptions::configuration_exception; +using host_name = service::vector_store_client::host_name; +using port_number = service::vector_store_client::port_number; + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +logging::logger vslogger("vector_store_client"); + +auto parse_port(std::string const& port_txt) -> std::optional { + auto port = port_number{}; + auto [ptr, ec] = std::from_chars(&*port_txt.begin(), &*port_txt.end(), port); + if (*ptr != '\0' || ec != std::errc{}) { + return std::nullopt; + } + return port; +} + +auto parse_service_uri(std::string_view uri) -> std::optional> { + constexpr auto URI_REGEX = R"(^http:\/\/([a-z0-9._-]+):([0-9]+)$)"; + auto const uri_regex = std::regex(URI_REGEX); + auto uri_match = std::smatch{}; + auto uri_txt = std::string(uri); + if (!std::regex_match(uri_txt, uri_match, uri_regex) || uri_match.size() != 3) { + return {}; + } + auto host = uri_match[1].str(); + auto port = parse_port(uri_match[2].str()); + if (!port) { + return {}; + } + return {{host, *port}}; +} + +} // namespace + +namespace service { + +struct vector_store_client::impl { + host_name host; + port_number port{}; + gate tasks_gate; + + impl(host_name host_, port_number port_) + : host(std::move(host_)) + , port(port_) { + } +}; + +vector_store_client::vector_store_client(config const& cfg) { + auto config_uri = cfg.vector_store_uri(); + if (config_uri.empty()) { + vslogger.info("Vector Store service URI is not configured."); + return; + } + + auto parsed_uri = parse_service_uri(config_uri); + if (!parsed_uri) { + throw configuration_exception(format("Invalid Vector Store service URI: {}", config_uri)); + } + + auto [host, port] = *parsed_uri; + _impl = std::make_unique(std::move(host), port); + vslogger.info("Vector Store service uri = {}:{}.", _impl->host, _impl->port); +} + +vector_store_client::~vector_store_client() = default; + +void vector_store_client::start_background_tasks() { + if (is_disabled()) { + return; + } +} + +auto vector_store_client::stop() -> future<> { + if (is_disabled()) { + co_return; + } + co_await _impl->tasks_gate.close(); +} + +auto vector_store_client::host() const -> std::expected { + if (is_disabled()) { + return std::unexpected{disabled{}}; + } + return {_impl->host}; +} + +auto vector_store_client::port() const -> std::expected { + if (is_disabled()) { + return std::unexpected{disabled{}}; + } + return {_impl->port}; +} + +} // namespace service + diff --git a/service/vector_store_client.hh b/service/vector_store_client.hh new file mode 100644 index 0000000000..00fcae2a1e --- /dev/null +++ b/service/vector_store_client.hh @@ -0,0 +1,63 @@ +/* + * Copyright (C) 2025-present ScyllaDB + */ + +/* + * SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 + */ + +#pragma once + +#include "seastarx.hh" +#include +#include +#include + +namespace db { +class config; +} + +namespace service { + +/// A client with the vector-store service. +class vector_store_client final { + struct impl; + std::unique_ptr _impl; + +public: + using config = db::config; + using host_name = sstring; + using port_number = std::uint16_t; + + /// The vector_store_client service is disabled. + struct disabled {}; + + explicit vector_store_client(config const& cfg); + ~vector_store_client(); + + /// Start background tasks. + void start_background_tasks(); + + /// Stop the service. + auto stop() -> future<>; + + /// Check if the vector_store_client is disabled. + auto is_disabled() const { + return !bool{_impl}; + } + + /// Get the current host name. + [[nodiscard]] auto host() const -> std::expected; + + /// Get the current port number. + [[nodiscard]] auto port() const -> std::expected; + +private: + friend struct vector_store_client_tester; +}; + +/// A tester for the vector_store_client, used for testing purposes. +struct vector_store_client_tester {}; + +} // namespace service + diff --git a/test/boost/CMakeLists.txt b/test/boost/CMakeLists.txt index 1fb8106a07..06bcb078f3 100644 --- a/test/boost/CMakeLists.txt +++ b/test/boost/CMakeLists.txt @@ -286,6 +286,8 @@ add_scylla_test(wrapping_interval_test KIND BOOST) add_scylla_test(address_map_test KIND SEASTAR) +add_scylla_test(vector_store_client_test + KIND SEASTAR) add_scylla_test(combined_tests KIND SEASTAR diff --git a/test/boost/vector_store_client_test.cc b/test/boost/vector_store_client_test.cc new file mode 100644 index 0000000000..abb3355485 --- /dev/null +++ b/test/boost/vector_store_client_test.cc @@ -0,0 +1,64 @@ +/* + * Copyright (C) 2025-present ScyllaDB + */ + +/* + * SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 + */ + +#include "service/vector_store_client.hh" +#include "db/config.hh" +#include "exceptions/exceptions.hh" +#include +#include +#include +#include +#include +#include +#include + + +namespace { + +using namespace seastar; + +using vector_store_client = service::vector_store_client; +using vector_store_client_tester = service::vector_store_client_tester; +using config = vector_store_client::config; +using configuration_exception = exceptions::configuration_exception; + +} // namespace + +BOOST_AUTO_TEST_CASE(vector_store_client_test_ctor) { + { + auto cfg = config(); + auto vs = vector_store_client{cfg}; + BOOST_CHECK(vs.is_disabled()); + BOOST_CHECK(!vs.host()); + BOOST_CHECK(!vs.port()); + } + { + auto cfg = config(); + cfg.vector_store_uri.set("http://good.authority.com:6080"); + auto vs = vector_store_client{cfg}; + BOOST_CHECK(!vs.is_disabled()); + BOOST_CHECK_EQUAL(*vs.host(), "good.authority.com"); + BOOST_CHECK_EQUAL(*vs.port(), 6080); + } + { + auto cfg = config(); + cfg.vector_store_uri.set("http://bad,authority.com:6080"); + BOOST_CHECK_THROW(vector_store_client{cfg}, configuration_exception); + cfg.vector_store_uri.set("bad-schema://authority.com:6080"); + BOOST_CHECK_THROW(vector_store_client{cfg}, configuration_exception); + cfg.vector_store_uri.set("http://bad.port.com:a6080"); + BOOST_CHECK_THROW(vector_store_client{cfg}, configuration_exception); + cfg.vector_store_uri.set("http://bad.port.com:60806080"); + BOOST_CHECK_THROW(vector_store_client{cfg}, configuration_exception); + cfg.vector_store_uri.set("http://bad.format.com:60:80"); + BOOST_CHECK_THROW(vector_store_client{cfg}, configuration_exception); + cfg.vector_store_uri.set("http://authority.com:6080/bad/path"); + BOOST_CHECK_THROW(vector_store_client{cfg}, configuration_exception); + } +} + diff --git a/test/lib/cql_test_env.cc b/test/lib/cql_test_env.cc index bd4bccda23..b12a5caa51 100644 --- a/test/lib/cql_test_env.cc +++ b/test/lib/cql_test_env.cc @@ -170,6 +170,7 @@ private: sharded _gossip_address_map; sharded _fd_pinger; sharded _cdc; + sharded _vector_store_client; db::config* _db_config; service::raft_group0_client* _group0_client; @@ -704,7 +705,12 @@ private: std::chrono::duration_cast(cql3::prepared_statements_cache::entry_expiry)); auth_prep_cache_config.refresh = std::chrono::milliseconds(cfg->permissions_update_interval_in_ms()); - _qp.start(std::ref(_proxy), std::move(local_data_dict), std::ref(_mnotifier), qp_mcfg, std::ref(_cql_config), auth_prep_cache_config, std::ref(_lang_manager)).get(); + _vector_store_client.start(std::ref(*cfg)).get(); + auto stop_vector_store_client = defer_verbose_shutdown("vector store client", [this] { + _vector_store_client.stop().get(); + }); + + _qp.start(std::ref(_proxy), std::move(local_data_dict), std::ref(_mnotifier), std::ref(_vector_store_client), qp_mcfg, std::ref(_cql_config), auth_prep_cache_config, std::ref(_lang_manager)).get(); auto stop_qp = defer_verbose_shutdown("query processor", [this] { _qp.stop().get(); }); _elc_notif.start().get(); From 8d3c33f74aa3b9d1bde9b8e5afa78de255df4907 Mon Sep 17 00:00:00 2001 From: Pawel Pery Date: Wed, 2 Jul 2025 19:26:32 +0200 Subject: [PATCH 2/5] utils: refactor sequential_producer as abortable This patch is a part of vector_store_client sharded service implementation for a communication with vector-store service. There is a need for abortable sequention_producer operator(). The existing operator() is changed to allow timeout argument with default time_point::max() (as current default usage) and the new operator() is created with abort_source parameter. Reference: VS-47 --- utils/sequential_producer.hh | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/utils/sequential_producer.hh b/utils/sequential_producer.hh index a075781510..853129f490 100644 --- a/utils/sequential_producer.hh +++ b/utils/sequential_producer.hh @@ -21,6 +21,7 @@ template class sequential_producer { public: using factory_t = std::function()>; + using time_point = seastar::shared_future::time_point; private: factory_t _factory; @@ -32,11 +33,18 @@ class sequential_producer { clear(); } - seastar::future operator()() { + seastar::future operator()(time_point timeout = time_point::max()) { if (_churning.available()) { _churning = _factory(); } - return _churning.get_future(); + return _churning.get_future(timeout); + } + + seastar::future operator()(seastar::abort_source& as) { + if (_churning.available()) { + _churning = _factory(); + } + return _churning.get_future(as); } void clear() { From 1f797e2fcd133f91da9e56785e29cbbe173576ac Mon Sep 17 00:00:00 2001 From: Pawel Pery Date: Wed, 2 Jul 2025 19:26:32 +0200 Subject: [PATCH 3/5] vector_store_client: implement ip addr retrieval from dns This patch is a part of vector_store_client sharded service implementation for a communication with vector-store service. It implements functionality for refreshing ip address of the vector-store service dns name and creating a new HTTP client with that address. It also provides cleanup of unused http clients. There are hardcoded intervals for dns refresh and old http clients cleanup, and timeout for requesting new http client. This patch introduces two background tasks - for dns resolving task and for cleanup old http clients. It adds unit tests for possible dns refreshing issues. Reference: VS-47 Fixes: VS-45 --- service/vector_store_client.cc | 257 ++++++++++++++++++++++++- service/vector_store_client.hh | 18 +- test/boost/vector_store_client_test.cc | 182 +++++++++++++++++ 3 files changed, 455 insertions(+), 2 deletions(-) diff --git a/service/vector_store_client.cc b/service/vector_store_client.cc index 9b1ceb9b09..0f013082ba 100644 --- a/service/vector_store_client.cc +++ b/service/vector_store_client.cc @@ -9,14 +9,35 @@ #include "vector_store_client.hh" #include "db/config.hh" #include "exceptions/exceptions.hh" +#include "utils/sequential_producer.hh" #include +#include #include +#include +#include +#include +#include +#include namespace { using configuration_exception = exceptions::configuration_exception; +using duration = lowres_clock::duration; using host_name = service::vector_store_client::host_name; +using http_client = http::experimental::client; +using inet_address = seastar::net::inet_address; +using milliseconds = std::chrono::milliseconds; using port_number = service::vector_store_client::port_number; +using time_point = lowres_clock::time_point; + +// Wait time before retrying after an exception occurred +constexpr auto EXCEPTION_OCCURED_WAIT = std::chrono::seconds(5); + +// Minimum interval between dns name refreshes +constexpr auto DNS_REFRESH_INTERVAL = std::chrono::seconds(5); + +/// Timeout for waiting for a new client to be available +constexpr auto WAIT_FOR_CLIENT_TIMEOUT = std::chrono::seconds(5); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) logging::logger vslogger("vector_store_client"); @@ -46,18 +67,203 @@ auto parse_service_uri(std::string_view uri) -> std::optional future { + auto result = co_await coroutine::as_future(sleep_abortable(timeout, as)); + if (result.failed()) { + auto err = result.get_exception(); + if (as.abort_requested()) { + co_return false; + } + co_await coroutine::return_exception_ptr(std::move(err)); + } + co_return true; +} + +/// Wait for a condition variable to be signaled or timeout. +auto wait_for_signal(condition_variable& cv, time_point timeout) -> future { + auto result = co_await coroutine::as_future(cv.wait(timeout)); + if (result.failed()) { + auto err = result.get_exception(); + if (try_catch(err) != nullptr) { + co_return false; + } + co_await coroutine::return_exception_ptr(std::move(err)); + } + co_return true; +} + } // namespace namespace service { struct vector_store_client::impl { + lw_shared_ptr current_client; + std::vector> old_clients; host_name host; port_number port{}; + inet_address addr; + time_point last_dns_refresh; gate tasks_gate; + condition_variable refresh_cv; + condition_variable refresh_client_cv; + abort_source abort_refresh; + milliseconds dns_refresh_interval = DNS_REFRESH_INTERVAL; + milliseconds wait_for_client_timeout = WAIT_FOR_CLIENT_TIMEOUT; + std::function>(sstring const&)> dns_resolver; + sequential_producer> client_producer; impl(host_name host_, port_number port_) : host(std::move(host_)) - , port(port_) { + , port(port_) + , dns_resolver([](auto const& host) -> future> { + auto addr = co_await coroutine::as_future(net::dns::resolve_name(host)); + if (addr.failed()) { + auto err = addr.get_exception(); + if (try_catch(err) != nullptr) { + co_return std::nullopt; + } + co_await coroutine::return_exception_ptr(std::move(err)); + } + co_return co_await std::move(addr); + }) + , client_producer([&]() -> future> { + trigger_dns_refresh(); + co_await wait_for_signal(refresh_client_cv, lowres_clock::now() + wait_for_client_timeout); + co_return current_client; + }) { + } + + /// Refresh the http client with a new address resolved from the DNS name. + /// If the DNS resolution fails, the current client is set to nullptr. + /// If the address is the same as the current one, do nothing. + /// Old clients are saved for later cleanup in a specific task. + auto refresh_addr() -> future<> { + auto new_addr = co_await dns_resolver(host); + if (!new_addr) { + current_client = nullptr; + co_return; + } + + // Check if the new address is the same as the current one + if (current_client && *new_addr == addr) { + co_return; + } + + addr = *new_addr; + old_clients.emplace_back(current_client); + current_client = make_lw_shared(socket_address(addr, port)); + } + + /// A task for refreshing the vector store http client. + auto refresh_addr_task() -> future<> { + for (;;) { + auto exception_occured = false; + try { + if (abort_refresh.abort_requested()) { + break; + } + + // Do not refresh the service address too often + auto now = lowres_clock::now(); + auto current_duration = now - last_dns_refresh; + if (current_duration > dns_refresh_interval) { + last_dns_refresh = now; + co_await refresh_addr(); + } else { + // Wait till the end of the refreshing interval + if (co_await wait_for_timeout(dns_refresh_interval - current_duration, abort_refresh)) { + continue; + } + // If the wait was aborted, we stop refreshing + break; + } + + if (abort_refresh.abort_requested()) { + break; + } + + // new client is available + refresh_client_cv.broadcast(); + + co_await cleanup_old_clients(); + + co_await refresh_cv.when(); + } catch (const std::exception& e) { + vslogger.error("Vector Store Client refresh task failed: {}", e.what()); + exception_occured = true; + } catch (...) { + vslogger.error("Vector Store Client refresh task failed with unknown exception"); + exception_occured = true; + } + if (exception_occured) { + // If an exception occurred, we wait for the next signal to refresh the address + co_await wait_for_timeout(EXCEPTION_OCCURED_WAIT, abort_refresh); + } + } + + co_await cleanup_old_clients(); + co_await cleanup_current_client(); + } + + /// Request a DNS refresh in the specific task. + void trigger_dns_refresh() { + refresh_cv.signal(); + } + + /// Cleanup current client + auto cleanup_current_client() -> future<> { + if (current_client) { + co_await current_client->close(); + } + current_client = nullptr; + } + + /// Cleanup old clients that are no longer used. + auto cleanup_old_clients() -> future<> { + // iterate over old clients and close them. There is a co_await in the loop + // so we need to use [] accessor and copying clients to avoid dangling references of iterators. + // NOLINTNEXTLINE(modernize-loop-convert) + for (auto it = 0U; it < old_clients.size(); ++it) { + auto& client = old_clients[it]; + if (client && client.owned()) { + auto client_cloned = client; + co_await client_cloned->close(); + client_cloned = nullptr; + } + } + std::erase_if(old_clients, [](auto const& client) { + return !client; + }); + } + + struct get_client_response { + lw_shared_ptr client; ///< The http client. + host_name host; ///< The host name for the vector-store service. + }; + + using get_client_error = std::variant; + + /// Get the current http client or wait for a new one to be available. + auto get_client(abort_source& as) -> future> { + if (current_client) { + co_return get_client_response{.client = current_client, .host = host}; + } + + auto current_client = co_await coroutine::as_future(client_producer(as)); + + if (current_client.failed()) { + auto err = current_client.get_exception(); + if (as.abort_requested()) { + co_return std::unexpected{aborted{}}; + } + co_await coroutine::return_exception_ptr(std::move(err)); + } + auto client = co_await std::move(current_client); + if (!client) { + co_return std::unexpected{addr_unavailable{}}; + } + co_return get_client_response{.client = client, .host = host}; } }; @@ -84,12 +290,22 @@ void vector_store_client::start_background_tasks() { if (is_disabled()) { return; } + + /// start the background task to refresh the service address + (void)try_with_gate(_impl->tasks_gate, [this] { + return _impl->refresh_addr_task(); + }).handle_exception([](std::exception_ptr eptr) { + on_internal_error_noexcept(vslogger, format("The Vector Store Client refresh task failed: {}", eptr)); + }); } auto vector_store_client::stop() -> future<> { if (is_disabled()) { co_return; } + + _impl->abort_refresh.request_abort(); + _impl->refresh_cv.signal(); co_await _impl->tasks_gate.close(); } @@ -107,5 +323,44 @@ auto vector_store_client::port() const -> std::expected { return {_impl->port}; } +void vector_store_client_tester::set_dns_refresh_interval(vector_store_client& vsc, std::chrono::milliseconds interval) { + if (vsc.is_disabled()) { + on_internal_error(vslogger, "Cannot set dns_refresh_interval on a disabled vector store client"); + } + vsc._impl->dns_refresh_interval = interval; +} + +void vector_store_client_tester::set_wait_for_client_timeout(vector_store_client& vsc, std::chrono::milliseconds timeout) { + if (vsc.is_disabled()) { + on_internal_error(vslogger, "Cannot set wait_for_client_timeout on a disabled vector store client"); + } + vsc._impl->wait_for_client_timeout = timeout; +} + +void vector_store_client_tester::set_dns_resolver(vector_store_client& vsc, std::function>(sstring const&)> resolver) { + if (vsc.is_disabled()) { + on_internal_error(vslogger, "Cannot set dns_resolver on a disabled vector store client"); + } + vsc._impl->dns_resolver = std::move(resolver); +} + +void vector_store_client_tester::trigger_dns_resolver(vector_store_client& vsc) { + if (vsc.is_disabled()) { + on_internal_error(vslogger, "Cannot trigger a dns resolver on a disabled vector store client"); + } + vsc._impl->trigger_dns_refresh(); +} + +auto vector_store_client_tester::resolve_hostname(vector_store_client& vsc, abort_source& as) -> future> { + if (vsc.is_disabled()) { + on_internal_error(vslogger, "Cannot check hostname resolving on a disabled vector store client"); + } + auto client_host = co_await vsc._impl->get_client(as); + if (!client_host) { + co_return std::nullopt; + } + co_return vsc._impl->addr; +} + } // namespace service diff --git a/service/vector_store_client.hh b/service/vector_store_client.hh index 00fcae2a1e..f9aa5cdb09 100644 --- a/service/vector_store_client.hh +++ b/service/vector_store_client.hh @@ -17,6 +17,10 @@ namespace db { class config; } +namespace seastar::net { +class inet_address; +} + namespace service { /// A client with the vector-store service. @@ -32,6 +36,12 @@ public: /// The vector_store_client service is disabled. struct disabled {}; + /// The operation was aborted. + struct aborted {}; + + /// The vector-store addr is unavailable (not possible to get an addr from the dns service). + struct addr_unavailable {}; + explicit vector_store_client(config const& cfg); ~vector_store_client(); @@ -57,7 +67,13 @@ private: }; /// A tester for the vector_store_client, used for testing purposes. -struct vector_store_client_tester {}; +struct vector_store_client_tester { + static void set_dns_refresh_interval(vector_store_client& vsc, std::chrono::milliseconds interval); + static void set_wait_for_client_timeout(vector_store_client& vsc, std::chrono::milliseconds timeout); + static void set_dns_resolver(vector_store_client& vsc, std::function>(sstring const&)> resolver); + static void trigger_dns_resolver(vector_store_client& vsc); + static auto resolve_hostname(vector_store_client& vsc, abort_source& as) -> future>; +}; } // namespace service diff --git a/test/boost/vector_store_client_test.cc b/test/boost/vector_store_client_test.cc index abb3355485..58f64803ba 100644 --- a/test/boost/vector_store_client_test.cc +++ b/test/boost/vector_store_client_test.cc @@ -26,6 +26,23 @@ using vector_store_client = service::vector_store_client; using vector_store_client_tester = service::vector_store_client_tester; using config = vector_store_client::config; using configuration_exception = exceptions::configuration_exception; +using inet_address = seastar::net::inet_address; +using milliseconds = std::chrono::milliseconds; +using port_number = vector_store_client::port_number; + +auto repeat_until(milliseconds timeout, std::function()> func) -> future { + auto begin = lowres_clock::now(); + while (!co_await func()) { + if (lowres_clock::now() - begin > timeout) { + co_return false; + } + } + co_return true; +} + +auto print_addr(const inet_address& addr) -> sstring { + return format("{}", addr); +} } // namespace @@ -62,3 +79,168 @@ BOOST_AUTO_TEST_CASE(vector_store_client_test_ctor) { } } +/// Resolving of the hostname is started in start_background_tasks() +SEASTAR_TEST_CASE(vector_store_client_test_dns_started) { + auto cfg = config(); + cfg.vector_store_uri.set("http://good.authority.here:6080"); + + auto vs = vector_store_client{cfg}; + BOOST_CHECK(!vs.is_disabled()); + + vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(2000)); + vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(100)); + vector_store_client_tester::set_dns_resolver(vs, [](auto const& host) -> future> { + BOOST_CHECK_EQUAL(host, "good.authority.here"); + co_return inet_address("127.0.0.1"); + }); + + vs.start_background_tasks(); + + auto as = abort_source(); + auto addr = co_await vector_store_client_tester::resolve_hostname(vs, as); + BOOST_REQUIRE(addr); + BOOST_CHECK_EQUAL(print_addr(*addr), "127.0.0.1"); + + co_await vs.stop(); +} + +/// Unable to resolve the hostname +SEASTAR_TEST_CASE(vector_store_client_test_dns_resolve_failure) { + auto cfg = config(); + cfg.vector_store_uri.set("http://good.authority.here:6080"); + + + auto vs = vector_store_client{cfg}; + BOOST_CHECK(!vs.is_disabled()); + + vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(2000)); + vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(100)); + vector_store_client_tester::set_dns_resolver(vs, [](auto const& host) -> future> { + BOOST_CHECK_EQUAL(host, "good.authority.here"); + co_return std::nullopt; + }); + + vs.start_background_tasks(); + + auto as = abort_source(); + BOOST_CHECK(!co_await vector_store_client_tester::resolve_hostname(vs, as)); + + co_await vs.stop(); +} + +/// Resolving of the hostname is repeated after errors +SEASTAR_TEST_CASE(vector_store_client_test_dns_resolving_repeated) { + auto cfg = config(); + cfg.vector_store_uri.set("http://good.authority.here:6080"); + auto vs = vector_store_client{cfg}; + BOOST_CHECK(!vs.is_disabled()); + + vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(10)); + vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(20)); + auto count = 0; + vector_store_client_tester::set_dns_resolver(vs, [&count](auto const& host) -> future> { + BOOST_CHECK_EQUAL(host, "good.authority.here"); + count++; + if (count % 3 != 0) { + co_return std::nullopt; + } + co_return inet_address(format("127.0.0.{}", count)); + }); + + vs.start_background_tasks(); + + auto as = abort_source(); + BOOST_CHECK(co_await repeat_until(std::chrono::milliseconds(1000), [&vs, &as]() -> future { + co_return co_await vector_store_client_tester::resolve_hostname(vs, as); + })); + BOOST_CHECK_EQUAL(count, 3); + auto addr = co_await vector_store_client_tester::resolve_hostname(vs, as); + BOOST_REQUIRE(addr); + BOOST_CHECK_EQUAL(print_addr(*addr), "127.0.0.3"); + + vector_store_client_tester::trigger_dns_resolver(vs); + + BOOST_CHECK(co_await repeat_until(std::chrono::milliseconds(1000), [&vs, &as]() -> future { + co_return !co_await vector_store_client_tester::resolve_hostname(vs, as); + })); + + BOOST_CHECK(co_await repeat_until(std::chrono::milliseconds(1000), [&vs, &as]() -> future { + co_return co_await vector_store_client_tester::resolve_hostname(vs, as); + })); + BOOST_CHECK_EQUAL(count, 6); + addr = co_await vector_store_client_tester::resolve_hostname(vs, as); + BOOST_REQUIRE(addr); + BOOST_CHECK_EQUAL(print_addr(*addr), "127.0.0.6"); + + co_await vs.stop(); +} + +/// Minimal interval between DNS refreshes is respected +SEASTAR_TEST_CASE(vector_store_client_test_dns_refresh_respects_interval) { + auto cfg = config(); + cfg.vector_store_uri.set("http://good.authority.here:6080"); + auto vs = vector_store_client{cfg}; + BOOST_CHECK(!vs.is_disabled()); + + vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(10)); + vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(100)); + auto count = 0; + vector_store_client_tester::set_dns_resolver(vs, [&count](auto const& host) -> future> { + BOOST_CHECK_EQUAL(host, "good.authority.here"); + count++; + co_return inet_address("127.0.0.1"); + }); + + vs.start_background_tasks(); + co_await sleep(std::chrono::milliseconds(20)); // wait for the first DNS refresh + + auto as = abort_source(); + auto addr = co_await vector_store_client_tester::resolve_hostname(vs, as); + BOOST_REQUIRE(addr); + BOOST_CHECK_EQUAL(print_addr(*addr), "127.0.0.1"); + BOOST_CHECK_EQUAL(count, 1); + count = 0; + vector_store_client_tester::trigger_dns_resolver(vs); + vector_store_client_tester::trigger_dns_resolver(vs); + vector_store_client_tester::trigger_dns_resolver(vs); + vector_store_client_tester::trigger_dns_resolver(vs); + vector_store_client_tester::trigger_dns_resolver(vs); + co_await sleep(std::chrono::milliseconds(100)); // wait for the next DNS refresh + + addr = co_await vector_store_client_tester::resolve_hostname(vs, as); + BOOST_REQUIRE(addr); + BOOST_CHECK_EQUAL(print_addr(*addr), "127.0.0.1"); + BOOST_CHECK_GE(count, 1); + BOOST_CHECK_LE(count, 2); + + co_await vs.stop(); +} + +/// DNS refresh could be aborted +SEASTAR_TEST_CASE(vector_store_client_test_dns_refresh_aborted) { + auto cfg = config(); + cfg.vector_store_uri.set("http://good.authority.here:6080"); + auto vs = vector_store_client{cfg}; + BOOST_CHECK(!vs.is_disabled()); + + vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(10)); + vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(100)); + vector_store_client_tester::set_dns_resolver(vs, [](auto const& host) -> future> { + BOOST_CHECK_EQUAL(host, "good.authority.here"); + co_await sleep(std::chrono::milliseconds(100)); + co_return inet_address("127.0.0.1"); + }); + + vs.start_background_tasks(); + + auto as = abort_source(); + auto timeout = timer([&as]() { + as.request_abort(); + }); + timeout.arm(std::chrono::milliseconds(10)); + auto addr = co_await vector_store_client_tester::resolve_hostname(vs, as); + BOOST_CHECK(!addr); + + co_await vs.stop(); +} + From 5bfce5290eb21e442ff9891e656e6824988ad90a Mon Sep 17 00:00:00 2001 From: Pawel Pery Date: Wed, 2 Jul 2025 19:26:32 +0200 Subject: [PATCH 4/5] cql3: refactor primary_key as a top-level class This patch is a part of vector_store_client sharded service implementation for a communication with vector-store service. There is a need for forward declaration of primary_key class. This patch moves a nested definition of select_statement::primary_key (from a cql3::statements namespace) into a standalone class in a cql3::statements namespace. Reference: VS-47 --- cql3/statements/select_statement.cc | 4 ++-- cql3/statements/select_statement.hh | 13 +++++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/cql3/statements/select_statement.cc b/cql3/statements/select_statement.cc index dce13d92ae..6ae52c1c97 100644 --- a/cql3/statements/select_statement.cc +++ b/cql3/statements/select_statement.cc @@ -1471,10 +1471,10 @@ indexed_table_select_statement::find_index_partition_ranges(query_processor& qp, // Note: the partitions keys returned by this function are sorted // in token order. See issue #3423. -future, lw_shared_ptr>>> +future, lw_shared_ptr>>> indexed_table_select_statement::find_index_clustering_rows(query_processor& qp, service::query_state& state, const query_options& options) const { - using value_type = std::tuple, lw_shared_ptr>; + using value_type = std::tuple, lw_shared_ptr>; auto now = gc_clock::now(); auto timeout = db::timeout_clock::now() + get_timeout(state.get_client_state(), options); const uint64_t limit = get_inner_loop_limit(get_limit(options, _limit), _selection->is_aggregate()); diff --git a/cql3/statements/select_statement.hh b/cql3/statements/select_statement.hh index 9593b1fccd..03c760c1c6 100644 --- a/cql3/statements/select_statement.hh +++ b/cql3/statements/select_statement.hh @@ -43,6 +43,13 @@ namespace restrictions { namespace statements { + +/// Encapsulates a partition key and clustering key prefix as a primary key. +struct primary_key { + dht::decorated_key partition; + clustering_key_prefix clustering; +}; + /** * Encapsulates a completely parsed SELECT query, including the target * column family, expression, result count, and ordering clause. @@ -134,12 +141,6 @@ public: const query_options& options, gc_clock::time_point now, int32_t page_size, bool aggregate, bool nonpaged_filtering, uint64_t limit, std::optional cas_shard) const; - - struct primary_key { - dht::decorated_key partition; - clustering_key_prefix clustering; - }; - future> process_results(foreign_ptr> results, lw_shared_ptr cmd, const query_options& options, gc_clock::time_point now) const; From eadbf69d6f5cf7abdd86f3b5074577b8e3ffc724 Mon Sep 17 00:00:00 2001 From: Pawel Pery Date: Wed, 2 Jul 2025 19:26:32 +0200 Subject: [PATCH 5/5] vector_store_client: implement ANN API This patch is a part of vector_store_client sharded service implementation for a communication with vector-store service. It implements a functionality for ANN search request to a vector-store service. It sends request, receive response and after parsing it returns the list of primary keys. It adds json parsing functionality specific for the HTTP ANN API. It adds a hardcoded http request timeout for retrieving response from the Vector Store service. It also adds an automatic boost test of the ANN search interface, which uses a mockup http server in a background to simulate vector-store service. It adds a documentation for HTTP API protocol used used for ANN functionality. Fixes: VS-47 --- docs/dev/protocols.md | 45 ++++ service/vector_store_client.cc | 206 +++++++++++++++++ service/vector_store_client.hh | 33 +++ test/boost/vector_store_client_test.cc | 294 +++++++++++++++++++++++++ 4 files changed, 578 insertions(+) diff --git a/docs/dev/protocols.md b/docs/dev/protocols.md index ce413b2566..205aef2f5f 100644 --- a/docs/dev/protocols.md +++ b/docs/dev/protocols.md @@ -276,3 +276,48 @@ to Scylla's REST API port over the loopback (localhost) interface. The port on which scylla-jmx listens is by default port 7199. This port, and the listen address, can be overridden with the `-jp` and `-ja` options (respectively) of the `scylla-jmx` script. + +## Vector Store + +The Vector Search functionality within Scylla is provided by an external +service [vector-store](https://github.com/scylla/vector-store). That service is +responsible for creating and managing vector indexes built from data retrieved +from Scylla using the CQL protocol and CDC functionality from a table and a +custom index created in Scylla using `CREATE INDEX {keyspace}.{index} ON +{table}({embedding) USING 'vector_index'`. Scylla does a vector search by +delegating the search to the vector-store service using the HTTP API protocol. +Scylla is the HTTP client of the vector-store service. + +The supported vector-store HTTP API: + +### `POST /api/v1/indexes/{keyspace}/{index}/ann` + +This endpoint is for an ANN (Approximate Nearest Neighbor) search. + +Parameters: +- `keyspace`: The keyspace name of the index to search. +- `index`: The index name to search. + +Request Body: +```json +{ + "embedding": [0.1, 0.2, 0.3, ...], // The vector to search for. + "limit": 10 // The number of nearest neighbors to return. +} +``` + +Responses: + +- 200 OK: Returns the nearest neighbors found. + + Response Body (Structure of Arrays): + ```json + { + "distances": [0.1234, 0.5678, ...], // The distances to the nearest neighbors up to the limit provided by a request. + "primary_keys": { + "pk1_column_name": ["value1", "value2", ...], // The primary key values of the pk1_column_name with same size as "distances". + "ck1_column_name": ["value1", "value2", ...], // The primary key values of the ck1_column_name with same size as "distances". + } + } + ``` + diff --git a/service/vector_store_client.cc b/service/vector_store_client.cc index 0f013082ba..63d930ce05 100644 --- a/service/vector_store_client.cc +++ b/service/vector_store_client.cc @@ -7,27 +7,46 @@ */ #include "vector_store_client.hh" +#include "cql3/statements/select_statement.hh" +#include "cql3/type_json.hh" #include "db/config.hh" #include "exceptions/exceptions.hh" #include "utils/sequential_producer.hh" +#include "dht/i_partitioner.hh" +#include "keys.hh" +#include "utils/rjson.hh" +#include "schema/schema.hh" #include #include +#include #include #include #include #include +#include #include #include +#include +#include namespace { +using ann_error = service::vector_store_client::ann_error; using configuration_exception = exceptions::configuration_exception; using duration = lowres_clock::duration; +using embedding = service::vector_store_client::embedding; +using limit = service::vector_store_client::limit; using host_name = service::vector_store_client::host_name; using http_client = http::experimental::client; +using http_path = sstring; using inet_address = seastar::net::inet_address; +using json_content = sstring; using milliseconds = std::chrono::milliseconds; +using operation_type = httpd::operation_type; using port_number = service::vector_store_client::port_number; +using primary_key = service::vector_store_client::primary_key; +using primary_keys = service::vector_store_client::primary_keys; +using service_reply_format_error = service::vector_store_client::service_reply_format_error; using time_point = lowres_clock::time_point; // Wait time before retrying after an exception occurred @@ -39,6 +58,9 @@ constexpr auto DNS_REFRESH_INTERVAL = std::chrono::seconds(5); /// Timeout for waiting for a new client to be available constexpr auto WAIT_FOR_CLIENT_TIMEOUT = std::chrono::seconds(5); +/// How many retries to do for HTTP requests +constexpr auto HTTP_REQUEST_RETRIES = 3; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) logging::logger vslogger("vector_store_client"); @@ -93,6 +115,97 @@ auto wait_for_signal(condition_variable& cv, time_point timeout) -> future co_return true; } +auto get_key_column_value(const rjson::value& item, std::size_t idx, const column_definition& column) -> std::expected { + auto const& column_name = column.name_as_text(); + auto const* keys_obj = rjson::find(item, column_name); + if (keys_obj == nullptr) { + vslogger.error("Vector Store returned invalid JSON: missing key column '{}'", column_name); + return std::unexpected{service_reply_format_error{}}; + } + if (!keys_obj->IsArray()) { + vslogger.error("Vector Store returned invalid JSON: key column '{}' is not an array", column_name); + return std::unexpected{service_reply_format_error{}}; + } + auto const& keys_arr = keys_obj->GetArray(); + if (keys_arr.Size() <= idx) { + vslogger.error("Vector Store returned invalid JSON: key column '{}' array too small", column_name); + return std::unexpected{service_reply_format_error{}}; + } + auto const& key = keys_arr[idx]; + return from_json_object(*column.type, key); +} + +auto pk_from_json(rjson::value const& item, std::size_t idx, schema_ptr const& schema) -> std::expected { + std::vector raw_pk; + for (const column_definition& cdef : schema->partition_key_columns()) { + auto raw_value = get_key_column_value(item, idx, cdef); + if (!raw_value) { + return std::unexpected{raw_value.error()}; + } + raw_pk.emplace_back(*raw_value); + } + return partition_key::from_exploded(raw_pk); +} + +auto ck_from_json(rjson::value const& item, std::size_t idx, schema_ptr const& schema) -> std::expected { + if (schema->clustering_key_size() == 0) { + return clustering_key_prefix::make_empty(); + } + + std::vector raw_ck; + for (const column_definition& cdef : schema->clustering_key_columns()) { + auto raw_value = get_key_column_value(item, idx, cdef); + if (!raw_value) { + return std::unexpected{raw_value.error()}; + } + raw_ck.emplace_back(*raw_value); + } + + return clustering_key_prefix::from_exploded(raw_ck); +} + +auto write_ann_json(embedding embedding, limit limit) -> json_content { + return seastar::format(R"({{"embedding":[{}],"limit":{}}})", fmt::join(embedding, ","), limit); +} + +auto read_ann_json(rjson::value const& json, schema_ptr const& schema) -> std::expected { + if (!json.HasMember("primary_keys")) { + vslogger.error("Vector Store returned invalid JSON: missing 'primary_keys'"); + return std::unexpected{service_reply_format_error{}}; + } + auto const& keys_json = json["primary_keys"]; + if (!keys_json.IsObject()) { + vslogger.error("Vector Store returned invalid JSON: 'primary_keys' is not an object"); + return std::unexpected{service_reply_format_error{}}; + } + + if (!json.HasMember("distances")) { + vslogger.error("Vector Store returned invalid JSON: missing 'distances'"); + return std::unexpected{service_reply_format_error{}}; + } + auto const& distances_json = json["distances"]; + if (!distances_json.IsArray()) { + vslogger.error("Vector Store returned invalid JSON: 'distances' is not an array"); + return std::unexpected{service_reply_format_error{}}; + } + auto const& distances_arr = json["distances"].GetArray(); + + auto size = distances_arr.Size(); + auto keys = primary_keys{}; + for (auto idx = 0U; idx < size; ++idx) { + auto pk = pk_from_json(keys_json, idx, schema); + if (!pk) { + return std::unexpected{pk.error()}; + } + auto ck = ck_from_json(keys_json, idx, schema); + if (!ck) { + return std::unexpected{ck.error()}; + } + keys.push_back(primary_key{dht::decorate_key(*schema, *pk), *ck}); + } + return std::move(keys); +} + } // namespace namespace service { @@ -110,6 +223,7 @@ struct vector_store_client::impl { abort_source abort_refresh; milliseconds dns_refresh_interval = DNS_REFRESH_INTERVAL; milliseconds wait_for_client_timeout = WAIT_FOR_CLIENT_TIMEOUT; + unsigned http_request_retries = HTTP_REQUEST_RETRIES; std::function>(sstring const&)> dns_resolver; sequential_producer> client_producer; @@ -265,6 +379,59 @@ struct vector_store_client::impl { } co_return get_client_response{.client = client, .host = host}; } + + struct make_request_response { + http::reply::status_type status; ///< The HTTP status of the response. + std::vector> content; ///< The content of the response. + }; + + using make_request_error = std::variant; + + auto make_request(operation_type method, http_path path, std::optional content, abort_source& as) + -> future> { + auto resp = make_request_response{.status = http::reply::status_type::ok, .content = std::vector>()}; + + for (auto retries = 0; retries < HTTP_REQUEST_RETRIES; ++retries) { + auto client_host = co_await get_client(as); + if (!client_host) { + co_return std::unexpected{std::visit( + [](auto&& err) { + return make_request_error{err}; + }, + client_host.error())}; + } + auto [client, host] = *std::move(client_host); + + auto req = http::request::make(method, host, path); + if (content) { + req.write_body("json", *content); + } + + auto result = co_await coroutine::as_future(client->make_request( + std::move(req), + [&resp](http::reply const& reply, input_stream body) -> future<> { + resp.status = reply._status; + resp.content = co_await util::read_entire_stream(body); + }, + std::nullopt, &as)); + if (result.failed()) { + auto err = result.get_exception(); + if (as.abort_requested()) { + co_return std::unexpected{aborted{}}; + } + if (try_catch(err) == nullptr) { + co_await coroutine::return_exception_ptr(std::move(err)); + } + // std::system_error means that the server is unavailable, so we retry + } else { + co_return resp; + } + + trigger_dns_refresh(); + } + + co_return std::unexpected{service_unavailable{}}; + } }; vector_store_client::vector_store_client(config const& cfg) { @@ -323,6 +490,38 @@ auto vector_store_client::port() const -> std::expected { return {_impl->port}; } +auto vector_store_client::ann(keyspace_name keyspace, index_name name, schema_ptr schema, embedding embedding, limit limit, abort_source& as) + -> future> { + if (is_disabled()) { + vslogger.error("Disabled Vector Store while calling ann"); + co_return std::unexpected{disabled{}}; + } + + auto path = format("/api/v1/indexes/{}/{}/ann", keyspace, name); + auto content = write_ann_json(std::move(embedding), limit); + + auto resp = co_await _impl->make_request(operation_type::POST, std::move(path), std::move(content), as); + if (!resp) { + co_return std::unexpected{std::visit( + [](auto&& err) { + return ann_error{err}; + }, + resp.error())}; + } + + if (resp->status != status_type::ok) { + vslogger.error("Vector Store returned error: HTTP status {}: {}", resp->status, resp->content); + co_return std::unexpected{service_error{resp->status}}; + } + + try { + co_return read_ann_json(rjson::parse(std::move(resp->content)), schema); + } catch (const rjson::error& e) { + vslogger.error("Vector Store returned invalid JSON: {}", e.what()); + co_return std::unexpected{service_reply_format_error{}}; + } +} + void vector_store_client_tester::set_dns_refresh_interval(vector_store_client& vsc, std::chrono::milliseconds interval) { if (vsc.is_disabled()) { on_internal_error(vslogger, "Cannot set dns_refresh_interval on a disabled vector store client"); @@ -337,6 +536,13 @@ void vector_store_client_tester::set_wait_for_client_timeout(vector_store_client vsc._impl->wait_for_client_timeout = timeout; } +void vector_store_client_tester::set_http_request_retries(vector_store_client& vsc, unsigned retries) { + if (vsc.is_disabled()) { + on_internal_error(vslogger, "Cannot set http_request_retries on a disabled vector store client"); + } + vsc._impl->http_request_retries = retries; +} + void vector_store_client_tester::set_dns_resolver(vector_store_client& vsc, std::function>(sstring const&)> resolver) { if (vsc.is_disabled()) { on_internal_error(vslogger, "Cannot set dns_resolver on a disabled vector store client"); diff --git a/service/vector_store_client.hh b/service/vector_store_client.hh index f9aa5cdb09..ac6d26da77 100644 --- a/service/vector_store_client.hh +++ b/service/vector_store_client.hh @@ -11,8 +11,15 @@ #include "seastarx.hh" #include #include +#include #include +class schema; + +namespace cql3::statements { +class primary_key; +} + namespace db { class config; } @@ -30,8 +37,16 @@ class vector_store_client final { public: using config = db::config; + using embedding = std::vector; using host_name = sstring; + using index_name = sstring; + using keyspace_name = sstring; + using limit = std::size_t; using port_number = std::uint16_t; + using primary_key = cql3::statements::primary_key; + using primary_keys = std::vector; + using schema_ptr = lw_shared_ptr; + using status_type = http::reply::status_type; /// The vector_store_client service is disabled. struct disabled {}; @@ -42,6 +57,19 @@ public: /// The vector-store addr is unavailable (not possible to get an addr from the dns service). struct addr_unavailable {}; + /// The vector-store service is unavailable. + struct service_unavailable {}; + + /// The error from the vector-store service. + struct service_error { + status_type status; ///< The HTTP status code from the vector-store service. + }; + + /// An unsupported reply format from the vector-store service. + struct service_reply_format_error {}; + + using ann_error = std::variant; + explicit vector_store_client(config const& cfg); ~vector_store_client(); @@ -62,6 +90,10 @@ public: /// Get the current port number. [[nodiscard]] auto port() const -> std::expected; + /// Request the vector store service for the primary keys of the nearest neighbors + auto ann(keyspace_name keyspace, index_name name, schema_ptr schema, embedding embedding, limit limit, abort_source& as) + -> future>; + private: friend struct vector_store_client_tester; }; @@ -70,6 +102,7 @@ private: struct vector_store_client_tester { static void set_dns_refresh_interval(vector_store_client& vsc, std::chrono::milliseconds interval); static void set_wait_for_client_timeout(vector_store_client& vsc, std::chrono::milliseconds timeout); + static void set_http_request_retries(vector_store_client& vsc, unsigned retries); static void set_dns_resolver(vector_store_client& vsc, std::function>(sstring const&)> resolver); static void trigger_dns_resolver(vector_store_client& vsc); static auto resolve_hostname(vector_store_client& vsc, abort_source& as) -> future>; diff --git a/test/boost/vector_store_client_test.cc b/test/boost/vector_store_client_test.cc index 58f64803ba..ee15e4bbb7 100644 --- a/test/boost/vector_store_client_test.cc +++ b/test/boost/vector_store_client_test.cc @@ -9,13 +9,22 @@ #include "service/vector_store_client.hh" #include "db/config.hh" #include "exceptions/exceptions.hh" +#include "cql3/statements/select_statement.hh" +#include "test/lib/cql_test_env.hh" +#include "test/lib/log.hh" +#include #include #include +#include +#include +#include +#include #include #include #include #include #include +#include namespace { @@ -27,8 +36,46 @@ using vector_store_client_tester = service::vector_store_client_tester; using config = vector_store_client::config; using configuration_exception = exceptions::configuration_exception; using inet_address = seastar::net::inet_address; +using function_handler = httpd::function_handler; +using http_server = httpd::http_server; +using http_server_tester = httpd::http_server_tester; using milliseconds = std::chrono::milliseconds; +using operation_type = httpd::operation_type; using port_number = vector_store_client::port_number; +using reply = http::reply; +using request = http::request; +using routes = httpd::routes; +using status_type = http::reply::status_type; +using url = httpd::url; + +constexpr auto const* LOCALHOST = "127.0.0.1"; + +/// Generate an ephemeral port number for listening on localhost. +/// After closing this socket, the port should be not listened on for a while. +/// This is not guaranteed to be a robust solution, but it should work for most tests. +auto generate_unavailable_localhost_port() -> port_number { + auto inaddr = net::inet_address(LOCALHOST); + auto server = listen(socket_address(inaddr, 0)); + auto port = server.local_address().port(); + server.abort_accept(); + return port; +} + +auto listen_on_ephemeral_port(std::unique_ptr server) -> future, socket_address>> { + auto inaddr = net::inet_address(LOCALHOST); + auto const addr = socket_address(inaddr, 0); + co_await server->listen(addr); + auto const& listeners = http_server_tester::listeners(*server); + BOOST_CHECK_EQUAL(listeners.size(), 1); + co_return std::make_tuple(std::move(server), listeners[0].local_address().port()); +} + +auto new_http_server(std::function set_routes) -> future, socket_address>> { + auto server = std::make_unique("test_vector_store_client"); + set_routes(server->_routes); + server->set_content_streaming(true); + co_return co_await listen_on_ephemeral_port(std::move(server)); +} auto repeat_until(milliseconds timeout, std::function()> func) -> future { auto begin = lowres_clock::now(); @@ -244,3 +291,250 @@ SEASTAR_TEST_CASE(vector_store_client_test_dns_refresh_aborted) { co_await vs.stop(); } +SEASTAR_TEST_CASE(vector_store_client_ann_test_disabled) { + co_await do_with_cql_env([](cql_test_env& env) -> future<> { + co_await env.execute_cql(R"( + create table ks.vs ( + pk1 tinyint, pk2 tinyint, + ck1 tinyint, ck2 tinyint, + embedding vector, + primary key ((pk1, pk2), ck1, ck2)) + )"); + + auto schema = env.local_db().find_schema("ks", "vs"); + auto& vs = env.local_qp().vector_store_client(); + + auto as = abort_source(); + auto keys = co_await vs.ann("ks", "idx", schema, std::vector{0.1, 0.2, 0.3}, 2, as); + BOOST_REQUIRE(!keys); + BOOST_CHECK(std::holds_alternative(keys.error())); + }); +} + +SEASTAR_TEST_CASE(vector_store_client_test_ann_addr_unavailable) { + auto cfg = cql_test_config(); + cfg.db_config->vector_store_uri.set("http://bad.authority.here:6080"); + co_await do_with_cql_env( + [](cql_test_env& env) -> future<> { + co_await env.execute_cql(R"( + create table ks.vs ( + pk1 tinyint, pk2 tinyint, + ck1 tinyint, ck2 tinyint, + embedding vector, + primary key ((pk1, pk2), ck1, ck2)) + )"); + + auto schema = env.local_db().find_schema("ks", "vs"); + auto& vs = env.local_qp().vector_store_client(); + + vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(1000)); + vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(100)); + vector_store_client_tester::set_http_request_retries(vs, 3); + vector_store_client_tester::set_dns_resolver(vs, [](auto const& host) -> future> { + BOOST_CHECK_EQUAL(host, "bad.authority.here"); + co_return std::nullopt; + }); + + vs.start_background_tasks(); + + auto as = abort_source(); + auto keys = co_await vs.ann("ks", "idx", schema, std::vector{0.1, 0.2, 0.3}, 2, as); + BOOST_REQUIRE(!keys); + BOOST_CHECK(std::holds_alternative(keys.error())); + }, + cfg); +} + +SEASTAR_TEST_CASE(vector_store_client_test_ann_service_unavailable) { + auto cfg = cql_test_config(); + cfg.db_config->vector_store_uri.set(format("http://good.authority.here:{}", generate_unavailable_localhost_port())); + co_await do_with_cql_env( + [](cql_test_env& env) -> future<> { + co_await env.execute_cql(R"( + create table ks.vs ( + pk1 tinyint, pk2 tinyint, + ck1 tinyint, ck2 tinyint, + embedding vector, + primary key ((pk1, pk2), ck1, ck2)) + )"); + + auto schema = env.local_db().find_schema("ks", "vs"); + auto& vs = env.local_qp().vector_store_client(); + + vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(1000)); + vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(100)); + vector_store_client_tester::set_http_request_retries(vs, 3); + vector_store_client_tester::set_dns_resolver(vs, [](auto const& host) -> future> { + BOOST_CHECK_EQUAL(host, "good.authority.here"); + co_return inet_address("127.0.0.1"); + }); + + vs.start_background_tasks(); + + auto as = abort_source(); + auto keys = co_await vs.ann("ks", "idx", schema, std::vector{0.1, 0.2, 0.3}, 2, as); + BOOST_REQUIRE(!keys); + BOOST_CHECK(std::holds_alternative(keys.error())); + }, + cfg); +} + +SEASTAR_TEST_CASE(vector_store_client_test_ann_service_aborted) { + auto cfg = cql_test_config(); + cfg.db_config->vector_store_uri.set(format("http://good.authority.here:{}", generate_unavailable_localhost_port())); + co_await do_with_cql_env( + [](cql_test_env& env) -> future<> { + co_await env.execute_cql(R"( + create table ks.vs ( + pk1 tinyint, pk2 tinyint, + ck1 tinyint, ck2 tinyint, + embedding vector, + primary key ((pk1, pk2), ck1, ck2)) + )"); + + auto schema = env.local_db().find_schema("ks", "vs"); + auto& vs = env.local_qp().vector_store_client(); + + vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(10)); + vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(100)); + vector_store_client_tester::set_http_request_retries(vs, 3); + vector_store_client_tester::set_dns_resolver(vs, [](auto const& host) -> future> { + BOOST_CHECK_EQUAL(host, "good.authority.here"); + co_await sleep(std::chrono::milliseconds(100)); + co_return inet_address("127.0.0.1"); + }); + + vs.start_background_tasks(); + + auto as = abort_source(); + auto timeout = timer([&as]() { + as.request_abort(); + }); + timeout.arm(std::chrono::milliseconds(10)); + auto keys = co_await vs.ann("ks", "idx", schema, std::vector{0.1, 0.2, 0.3}, 2, as); + BOOST_REQUIRE(!keys); + BOOST_CHECK(std::holds_alternative(keys.error())); + }, + cfg); +} + + +SEASTAR_TEST_CASE(vector_store_client_test_ann_request) { + auto ann_replies = make_lw_shared>>(); + auto [server, addr] = co_await new_http_server([ann_replies](routes& r) { + auto ann = [ann_replies](std::unique_ptr req, std::unique_ptr rep) -> future> { + BOOST_REQUIRE(!ann_replies->empty()); + auto [req_exp, rep_inp] = ann_replies->front(); + auto const req_inp = co_await util::read_entire_stream_contiguous(*req->content_stream); + BOOST_CHECK_EQUAL(req_inp, req_exp); + ann_replies->pop(); + rep->set_status(status_type::ok); + rep->write_body("json", rep_inp); + co_return rep; + }; + r.add(operation_type::POST, url("/api/v1/indexes/ks/idx").remainder("ann"), new function_handler(ann, "json")); + }); + + auto cfg = cql_test_config(); + cfg.db_config->vector_store_uri.set(format("http://good.authority.here:{}", addr.port())); + co_await do_with_cql_env( + [&ann_replies](cql_test_env& env) -> future<> { + co_await env.execute_cql(R"( + create table ks.vs ( + pk1 tinyint, pk2 tinyint, + ck1 tinyint, ck2 tinyint, + embedding vector, + primary key ((pk1, pk2), ck1, ck2)) + )"); + + auto schema = env.local_db().find_schema("ks", "vs"); + auto& vs = env.local_qp().vector_store_client(); + + vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(1000)); + vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(100)); + vector_store_client_tester::set_http_request_retries(vs, 3); + vector_store_client_tester::set_dns_resolver(vs, [](auto const& host) -> future> { + BOOST_CHECK_EQUAL(host, "good.authority.here"); + co_return inet_address("127.0.0.1"); + }); + + vs.start_background_tasks(); + + // set the wrong idx (wrong endpoint) - service should return 404 + auto as = abort_source(); + auto keys = co_await vs.ann("ks", "idx2", schema, std::vector{0.1, 0.2, 0.3}, 2, as); + BOOST_REQUIRE(!keys); + auto* err = std::get_if(&keys.error()); + BOOST_CHECK(err != nullptr); + BOOST_CHECK_EQUAL(err->status, status_type::not_found); + + // missing primary_keys in the reply - service should return format error + ann_replies->emplace(std::make_tuple(R"({"embedding":[0.1,0.2,0.3],"limit":2})", + R"({"primary_keys1":{"pk1":[5,6],"pk2":[7,8],"ck1":[9,1],"ck2":[2,3]},"distances":[0.1,0.2]})")); + auto const now = lowres_clock::now(); + for (;;) { + keys = co_await vs.ann("ks", "idx", schema, std::vector{0.1, 0.2, 0.3}, 2, as); + BOOST_REQUIRE(!keys); + + // if the service is unavailable or 400, retry, seems http server is not ready yet + auto* const unavailable = std::get_if(&keys.error()); + auto* const service_error = std::get_if(&keys.error()); + if ((unavailable == nullptr && service_error == nullptr) || + (service_error != nullptr && service_error->status != status_type::bad_request)) { + constexpr auto MAX_WAIT = std::chrono::seconds(5); + BOOST_REQUIRE(lowres_clock::now() - now < MAX_WAIT); + break; + } + } + BOOST_CHECK(std::holds_alternative(keys.error())); + + // missing distances in the reply - service should return format error + ann_replies->emplace(std::make_tuple(R"({"embedding":[0.1,0.2,0.3],"limit":2})", + R"({"primary_keys":{"pk1":[5,6],"pk2":[7,8],"ck1":[9,1],"ck2":[2,3]},"distances1":[0.1,0.2]})")); + keys = co_await vs.ann("ks", "idx", schema, std::vector{0.1, 0.2, 0.3}, 2, as); + BOOST_REQUIRE(!keys); + BOOST_CHECK(std::holds_alternative(keys.error())); + + // missing pk1 key in the reply - service should return format error + ann_replies->emplace(std::make_tuple(R"({"embedding":[0.1,0.2,0.3],"limit":2})", + R"({"primary_keys":{"pk11":[5,6],"pk2":[7,8],"ck1":[9,1],"ck2":[2,3]},"distances":[0.1,0.2]})")); + keys = co_await vs.ann("ks", "idx", schema, std::vector{0.1, 0.2, 0.3}, 2, as); + BOOST_REQUIRE(!keys); + BOOST_CHECK(std::holds_alternative(keys.error())); + + // missing ck1 key in the reply - service should return format error + ann_replies->emplace(std::make_tuple(R"({"embedding":[0.1,0.2,0.3],"limit":2})", + R"({"primary_keys":{"pk1":[5,6],"pk2":[7,8],"ck11":[9,1],"ck2":[2,3]},"distances":[0.1,0.2]})")); + keys = co_await vs.ann("ks", "idx", schema, std::vector{0.1, 0.2, 0.3}, 2, as); + BOOST_REQUIRE(!keys); + BOOST_CHECK(std::holds_alternative(keys.error())); + + // wrong size of pk2 key in the reply - service should return format error + ann_replies->emplace(std::make_tuple(R"({"embedding":[0.1,0.2,0.3],"limit":2})", + R"({"primary_keys":{"pk1":[5,6],"pk2":[78],"ck1":[9,1],"ck2":[2,3]},"distances":[0.1,0.2]})")); + keys = co_await vs.ann("ks", "idx", schema, std::vector{0.1, 0.2, 0.3}, 2, as); + BOOST_REQUIRE(!keys); + BOOST_CHECK(std::holds_alternative(keys.error())); + + // wrong size of ck2 key in the reply - service should return format error + ann_replies->emplace(std::make_tuple(R"({"embedding":[0.1,0.2,0.3],"limit":2})", + R"({"primary_keys":{"pk1":[5,6],"pk2":[7,8],"ck1":[9,1],"ck2":[23]},"distances":[0.1,0.2]})")); + keys = co_await vs.ann("ks", "idx", schema, std::vector{0.1, 0.2, 0.3}, 2, as); + BOOST_REQUIRE(!keys); + BOOST_CHECK(std::holds_alternative(keys.error())); + + // correct reply - service should return keys + ann_replies->emplace(std::make_tuple(R"({"embedding":[0.1,0.2,0.3],"limit":2})", + R"({"primary_keys":{"pk1":[5,6],"pk2":[7,8],"ck1":[9,1],"ck2":[2,3]},"distances":[0.1,0.2]})")); + keys = co_await vs.ann("ks", "idx", schema, std::vector{0.1, 0.2, 0.3}, 2, as); + BOOST_REQUIRE(keys); + BOOST_REQUIRE_EQUAL(keys->size(), 2); + BOOST_CHECK_EQUAL(seastar::format("{}", keys->at(0).partition.key().explode()), "[05, 07]"); + BOOST_CHECK_EQUAL(seastar::format("{}", keys->at(0).clustering.explode()), "[09, 02]"); + BOOST_CHECK_EQUAL(seastar::format("{}", keys->at(1).partition.key().explode()), "[06, 08]"); + BOOST_CHECK_EQUAL(seastar::format("{}", keys->at(1).clustering.explode()), "[01, 03]"); + }, + cfg); + co_await server->stop(); +} +