diff --git a/conf/scylla.yaml b/conf/scylla.yaml index fd6c9e9bd9..00bbd53f7f 100644 --- a/conf/scylla.yaml +++ b/conf/scylla.yaml @@ -855,3 +855,10 @@ rf_rack_valid_keyspaces: false # Maximum number of items in single BatchWriteItem command. Default is 100. # Note: DynamoDB has a hard-coded limit of 25. # alternator_max_items_in_batch_write: 100 + +# +# Vector Store options +# +# Uri for the vector store using dns name. Only http schema is supported. Port number is mandatory. +# Default is empty, which means that the vector store is not used. +# vector_store_uri: http://vector-store.dns.name:{port} diff --git a/configure.py b/configure.py index 04fabc1159..0a916d47ee 100755 --- a/configure.py +++ b/configure.py @@ -567,6 +567,7 @@ scylla_tests = set([ 'test/boost/symmetric_key_test', 'test/boost/types_test', 'test/boost/utf8_test', + 'test/boost/vector_store_client_test', 'test/boost/vint_serialization_test', 'test/boost/virtual_table_mutation_source_test', 'test/boost/wasm_alloc_test', @@ -1219,6 +1220,7 @@ scylla_core = (['message/messaging_service.cc', 'node_ops/task_manager_module.cc', 'reader_concurrency_semaphore_group.cc', 'utils/disk_space_monitor.cc', + 'service/vector_store_client.cc', ] + [Antlr3Grammar('cql3/Cql.g')] \ + scylla_raft_core ) diff --git a/cql3/query_processor.cc b/cql3/query_processor.cc index 0a5df800e2..4ede4485f4 100644 --- a/cql3/query_processor.cc +++ b/cql3/query_processor.cc @@ -27,6 +27,7 @@ #include "cql3/untyped_result_set.hh" #include "db/config.hh" #include "data_dictionary/data_dictionary.hh" +#include "service/vector_store_client.hh" #include "utils/hashers.hh" #include "utils/error_injection.hh" #include "service/migration_manager.hh" @@ -68,11 +69,12 @@ static service::query_state query_state_for_internal_call() { return {service::client_state::for_internal_calls(), empty_service_permit()}; } -query_processor::query_processor(service::storage_proxy& proxy, data_dictionary::database db, service::migration_notifier& mn, query_processor::memory_config mcfg, cql_config& cql_cfg, utils::loading_cache_config auth_prep_cache_cfg, lang::manager& langm) +query_processor::query_processor(service::storage_proxy& proxy, data_dictionary::database db, service::migration_notifier& mn, service::vector_store_client& vsc, query_processor::memory_config mcfg, cql_config& cql_cfg, utils::loading_cache_config auth_prep_cache_cfg, lang::manager& langm) : _migration_subscriber{std::make_unique(this)} , _proxy(proxy) , _db(db) , _mnotifier(mn) + , _vector_store_client(vsc) , _mcfg(mcfg) , _cql_config(cql_cfg) , _prepared_cache(prep_cache_log, _mcfg.prepared_statment_cache_size) diff --git a/cql3/query_processor.hh b/cql3/query_processor.hh index 13a2242184..b17ef86e98 100644 --- a/cql3/query_processor.hh +++ b/cql3/query_processor.hh @@ -28,6 +28,7 @@ #include "transport/messages/result_message.hh" #include "service/client_state.hh" #include "service/broadcast_tables/experimental/query_result.hh" +#include "service/vector_store_client.hh" #include "utils/assert.hh" #include "utils/observable.hh" #include "service/raft/raft_group0_client.hh" @@ -107,6 +108,7 @@ private: service::storage_proxy& _proxy; data_dictionary::database _db; service::migration_notifier& _mnotifier; + service::vector_store_client& _vector_store_client; memory_config _mcfg; const cql_config& _cql_config; @@ -146,7 +148,7 @@ public: static std::unique_ptr parse_statement(const std::string_view& query, dialect d); static std::vector> parse_statements(std::string_view queries, dialect d); - query_processor(service::storage_proxy& proxy, data_dictionary::database db, service::migration_notifier& mn, memory_config mcfg, cql_config& cql_cfg, utils::loading_cache_config auth_prep_cache_cfg, lang::manager& langm); + query_processor(service::storage_proxy& proxy, data_dictionary::database db, service::migration_notifier& mn, service::vector_store_client& vsc, memory_config mcfg, cql_config& cql_cfg, utils::loading_cache_config auth_prep_cache_cfg, lang::manager& langm); ~query_processor(); @@ -176,6 +178,14 @@ public: lang::manager& lang() { return _lang_manager; } + const service::vector_store_client& vector_store_client() const noexcept { + return _vector_store_client; + } + + service::vector_store_client& vector_store_client() noexcept { + return _vector_store_client; + } + db::auth_version_t auth_version; statements::prepared_statement::checked_weak_ptr get_prepared(const std::optional& user, const prepared_cache_key_type& key) { diff --git a/cql3/statements/select_statement.cc b/cql3/statements/select_statement.cc index dce13d92ae..6ae52c1c97 100644 --- a/cql3/statements/select_statement.cc +++ b/cql3/statements/select_statement.cc @@ -1471,10 +1471,10 @@ indexed_table_select_statement::find_index_partition_ranges(query_processor& qp, // Note: the partitions keys returned by this function are sorted // in token order. See issue #3423. -future, lw_shared_ptr>>> +future, lw_shared_ptr>>> indexed_table_select_statement::find_index_clustering_rows(query_processor& qp, service::query_state& state, const query_options& options) const { - using value_type = std::tuple, lw_shared_ptr>; + using value_type = std::tuple, lw_shared_ptr>; auto now = gc_clock::now(); auto timeout = db::timeout_clock::now() + get_timeout(state.get_client_state(), options); const uint64_t limit = get_inner_loop_limit(get_limit(options, _limit), _selection->is_aggregate()); diff --git a/cql3/statements/select_statement.hh b/cql3/statements/select_statement.hh index 9593b1fccd..03c760c1c6 100644 --- a/cql3/statements/select_statement.hh +++ b/cql3/statements/select_statement.hh @@ -43,6 +43,13 @@ namespace restrictions { namespace statements { + +/// Encapsulates a partition key and clustering key prefix as a primary key. +struct primary_key { + dht::decorated_key partition; + clustering_key_prefix clustering; +}; + /** * Encapsulates a completely parsed SELECT query, including the target * column family, expression, result count, and ordering clause. @@ -134,12 +141,6 @@ public: const query_options& options, gc_clock::time_point now, int32_t page_size, bool aggregate, bool nonpaged_filtering, uint64_t limit, std::optional cas_shard) const; - - struct primary_key { - dht::decorated_key partition; - clustering_key_prefix clustering; - }; - future> process_results(foreign_ptr> results, lw_shared_ptr cmd, const query_options& options, gc_clock::time_point now) const; diff --git a/db/config.cc b/db/config.cc index 72d62c3eca..3dc7e7049b 100644 --- a/db/config.cc +++ b/db/config.cc @@ -1361,6 +1361,7 @@ db::config::config(std::shared_ptr exts) // alternator_max_items_in_batch_write matches DynamoDB behaviour of size limit, but with different value - for DynamoDB it's 25 // (see DynamoDB's documentation for BatchWriteItem command) , alternator_max_items_in_batch_write(this, "alternator_max_items_in_batch_write", value_status::Used, 100, "Maximum amount of items in single BatchItemWrite call.") + , vector_store_uri(this, "vector_store_uri", value_status::Used, "", "The URI of the vector store to use for vector search. If not set, vector search is disabled.") , abort_on_ebadf(this, "abort_on_ebadf", value_status::Used, true, "Abort the server on incorrect file descriptor access. Throws exception when disabled.") , redis_port(this, "redis_port", value_status::Used, 0, "Port on which the REDIS transport listens for clients.") , redis_ssl_port(this, "redis_ssl_port", value_status::Used, 0, "Port on which the REDIS TLS native transport listens for clients.") diff --git a/db/config.hh b/db/config.hh index ff97a2f28c..1cc1c0415d 100644 --- a/db/config.hh +++ b/db/config.hh @@ -488,6 +488,8 @@ public: named_value alternator_describe_endpoints; named_value alternator_max_items_in_batch_write; + named_value vector_store_uri; + named_value abort_on_ebadf; named_value redis_port; diff --git a/docs/dev/protocols.md b/docs/dev/protocols.md index ce413b2566..205aef2f5f 100644 --- a/docs/dev/protocols.md +++ b/docs/dev/protocols.md @@ -276,3 +276,48 @@ to Scylla's REST API port over the loopback (localhost) interface. The port on which scylla-jmx listens is by default port 7199. This port, and the listen address, can be overridden with the `-jp` and `-ja` options (respectively) of the `scylla-jmx` script. + +## Vector Store + +The Vector Search functionality within Scylla is provided by an external +service [vector-store](https://github.com/scylla/vector-store). That service is +responsible for creating and managing vector indexes built from data retrieved +from Scylla using the CQL protocol and CDC functionality from a table and a +custom index created in Scylla using `CREATE INDEX {keyspace}.{index} ON +{table}({embedding) USING 'vector_index'`. Scylla does a vector search by +delegating the search to the vector-store service using the HTTP API protocol. +Scylla is the HTTP client of the vector-store service. + +The supported vector-store HTTP API: + +### `POST /api/v1/indexes/{keyspace}/{index}/ann` + +This endpoint is for an ANN (Approximate Nearest Neighbor) search. + +Parameters: +- `keyspace`: The keyspace name of the index to search. +- `index`: The index name to search. + +Request Body: +```json +{ + "embedding": [0.1, 0.2, 0.3, ...], // The vector to search for. + "limit": 10 // The number of nearest neighbors to return. +} +``` + +Responses: + +- 200 OK: Returns the nearest neighbors found. + + Response Body (Structure of Arrays): + ```json + { + "distances": [0.1234, 0.5678, ...], // The distances to the nearest neighbors up to the limit provided by a request. + "primary_keys": { + "pk1_column_name": ["value1", "value2", ...], // The primary key values of the pk1_column_name with same size as "distances". + "ck1_column_name": ["value1", "value2", ...], // The primary key values of the ck1_column_name with same size as "distances". + } + } + ``` + diff --git a/main.cc b/main.cc index 19dae2099a..929cc2075d 100644 --- a/main.cc +++ b/main.cc @@ -40,6 +40,7 @@ #include "service/migration_manager.hh" #include "service/tablet_allocator.hh" #include "service/load_meter.hh" +#include "service/vector_store_client.hh" #include "service/view_update_backlog_broker.hh" #include "service/qos/service_level_controller.hh" #include "streaming/stream_session.hh" @@ -740,6 +741,7 @@ sharded token_metadata; sharded mapreduce_service; sharded gossiper; sharded snitch; + sharded vector_store_client; // This worker wasn't designed to be used from multiple threads. // If you are attempting to do that, make sure you know what you are doing. @@ -779,7 +781,8 @@ sharded token_metadata; return seastar::async([&app, cfg, ext, &disk_space_monitor_shard0, &cm, &sstm, &db, &qp, &bm, &proxy, &mapreduce_service, &mm, &mm_notifier, &ctx, &opts, &dirs, &prometheus_server, &cf_cache_hitrate_calculator, &load_meter, &feature_service, &gossiper, &snitch, &token_metadata, &erm_factory, &snapshot_ctl, &messaging, &sst_dir_semaphore, &raft_gr, &service_memory_limiter, - &repair, &sst_loader, &ss, &lifecycle_notifier, &stream_manager, &task_manager, &rpc_dict_training_worker] { + &repair, &sst_loader, &ss, &lifecycle_notifier, &stream_manager, &task_manager, &rpc_dict_training_worker, + &vector_store_client] { try { if (opts.contains("relabel-config-file") && !opts["relabel-config-file"].as().empty()) { // calling update_relabel_config_from_file can cause an exception that would stop startup @@ -1318,6 +1321,13 @@ sharded token_metadata; static sharded cql_config; cql_config.start(std::ref(*cfg)).get(); + checkpoint(stop_signal, "starting a vector store service"); + vector_store_client.start(std::ref(*cfg)).get(); + auto stop_vector_store_client = defer_verbose_shutdown("vector store client", [&vector_store_client] { + vector_store_client.stop().get(); + }); + vector_store_client.invoke_on_all(&service::vector_store_client::start_background_tasks).get(); + checkpoint(stop_signal, "starting query processor"); cql3::query_processor::memory_config qp_mcfg = {memory::stats().total_memory() / 256, memory::stats().total_memory() / 2560}; debug::the_query_processor = &qp; @@ -1329,7 +1339,7 @@ sharded token_metadata; std::chrono::duration_cast(cql3::prepared_statements_cache::entry_expiry)); auth_prep_cache_config.refresh = std::chrono::milliseconds(cfg->permissions_update_interval_in_ms()); - qp.start(std::ref(proxy), std::move(local_data_dict), std::ref(mm_notifier), qp_mcfg, std::ref(cql_config), std::move(auth_prep_cache_config), std::ref(langman)).get(); + qp.start(std::ref(proxy), std::move(local_data_dict), std::ref(mm_notifier), std::ref(vector_store_client), qp_mcfg, std::ref(cql_config), std::move(auth_prep_cache_config), std::ref(langman)).get(); checkpoint(stop_signal, "starting lifecycle notifier"); lifecycle_notifier.start().get(); diff --git a/service/CMakeLists.txt b/service/CMakeLists.txt index 4c672cb3d9..1e1ee39580 100644 --- a/service/CMakeLists.txt +++ b/service/CMakeLists.txt @@ -33,7 +33,8 @@ target_sources(service task_manager_module.cc topology_coordinator.cc topology_mutation.cc - topology_state_machine.cc) + topology_state_machine.cc + vector_store_client.cc) target_include_directories(service PUBLIC ${CMAKE_SOURCE_DIR}) diff --git a/service/vector_store_client.cc b/service/vector_store_client.cc new file mode 100644 index 0000000000..63d930ce05 --- /dev/null +++ b/service/vector_store_client.cc @@ -0,0 +1,572 @@ +/* + * Copyright (C) 2025-present ScyllaDB + */ + +/* + * SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 + */ + +#include "vector_store_client.hh" +#include "cql3/statements/select_statement.hh" +#include "cql3/type_json.hh" +#include "db/config.hh" +#include "exceptions/exceptions.hh" +#include "utils/sequential_producer.hh" +#include "dht/i_partitioner.hh" +#include "keys.hh" +#include "utils/rjson.hh" +#include "schema/schema.hh" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { + +using ann_error = service::vector_store_client::ann_error; +using configuration_exception = exceptions::configuration_exception; +using duration = lowres_clock::duration; +using embedding = service::vector_store_client::embedding; +using limit = service::vector_store_client::limit; +using host_name = service::vector_store_client::host_name; +using http_client = http::experimental::client; +using http_path = sstring; +using inet_address = seastar::net::inet_address; +using json_content = sstring; +using milliseconds = std::chrono::milliseconds; +using operation_type = httpd::operation_type; +using port_number = service::vector_store_client::port_number; +using primary_key = service::vector_store_client::primary_key; +using primary_keys = service::vector_store_client::primary_keys; +using service_reply_format_error = service::vector_store_client::service_reply_format_error; +using time_point = lowres_clock::time_point; + +// Wait time before retrying after an exception occurred +constexpr auto EXCEPTION_OCCURED_WAIT = std::chrono::seconds(5); + +// Minimum interval between dns name refreshes +constexpr auto DNS_REFRESH_INTERVAL = std::chrono::seconds(5); + +/// Timeout for waiting for a new client to be available +constexpr auto WAIT_FOR_CLIENT_TIMEOUT = std::chrono::seconds(5); + +/// How many retries to do for HTTP requests +constexpr auto HTTP_REQUEST_RETRIES = 3; + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +logging::logger vslogger("vector_store_client"); + +auto parse_port(std::string const& port_txt) -> std::optional { + auto port = port_number{}; + auto [ptr, ec] = std::from_chars(&*port_txt.begin(), &*port_txt.end(), port); + if (*ptr != '\0' || ec != std::errc{}) { + return std::nullopt; + } + return port; +} + +auto parse_service_uri(std::string_view uri) -> std::optional> { + constexpr auto URI_REGEX = R"(^http:\/\/([a-z0-9._-]+):([0-9]+)$)"; + auto const uri_regex = std::regex(URI_REGEX); + auto uri_match = std::smatch{}; + auto uri_txt = std::string(uri); + if (!std::regex_match(uri_txt, uri_match, uri_regex) || uri_match.size() != 3) { + return {}; + } + auto host = uri_match[1].str(); + auto port = parse_port(uri_match[2].str()); + if (!port) { + return {}; + } + return {{host, *port}}; +} + +/// Wait for a timeout ar abort signal. +auto wait_for_timeout(duration timeout, abort_source& as) -> future { + auto result = co_await coroutine::as_future(sleep_abortable(timeout, as)); + if (result.failed()) { + auto err = result.get_exception(); + if (as.abort_requested()) { + co_return false; + } + co_await coroutine::return_exception_ptr(std::move(err)); + } + co_return true; +} + +/// Wait for a condition variable to be signaled or timeout. +auto wait_for_signal(condition_variable& cv, time_point timeout) -> future { + auto result = co_await coroutine::as_future(cv.wait(timeout)); + if (result.failed()) { + auto err = result.get_exception(); + if (try_catch(err) != nullptr) { + co_return false; + } + co_await coroutine::return_exception_ptr(std::move(err)); + } + co_return true; +} + +auto get_key_column_value(const rjson::value& item, std::size_t idx, const column_definition& column) -> std::expected { + auto const& column_name = column.name_as_text(); + auto const* keys_obj = rjson::find(item, column_name); + if (keys_obj == nullptr) { + vslogger.error("Vector Store returned invalid JSON: missing key column '{}'", column_name); + return std::unexpected{service_reply_format_error{}}; + } + if (!keys_obj->IsArray()) { + vslogger.error("Vector Store returned invalid JSON: key column '{}' is not an array", column_name); + return std::unexpected{service_reply_format_error{}}; + } + auto const& keys_arr = keys_obj->GetArray(); + if (keys_arr.Size() <= idx) { + vslogger.error("Vector Store returned invalid JSON: key column '{}' array too small", column_name); + return std::unexpected{service_reply_format_error{}}; + } + auto const& key = keys_arr[idx]; + return from_json_object(*column.type, key); +} + +auto pk_from_json(rjson::value const& item, std::size_t idx, schema_ptr const& schema) -> std::expected { + std::vector raw_pk; + for (const column_definition& cdef : schema->partition_key_columns()) { + auto raw_value = get_key_column_value(item, idx, cdef); + if (!raw_value) { + return std::unexpected{raw_value.error()}; + } + raw_pk.emplace_back(*raw_value); + } + return partition_key::from_exploded(raw_pk); +} + +auto ck_from_json(rjson::value const& item, std::size_t idx, schema_ptr const& schema) -> std::expected { + if (schema->clustering_key_size() == 0) { + return clustering_key_prefix::make_empty(); + } + + std::vector raw_ck; + for (const column_definition& cdef : schema->clustering_key_columns()) { + auto raw_value = get_key_column_value(item, idx, cdef); + if (!raw_value) { + return std::unexpected{raw_value.error()}; + } + raw_ck.emplace_back(*raw_value); + } + + return clustering_key_prefix::from_exploded(raw_ck); +} + +auto write_ann_json(embedding embedding, limit limit) -> json_content { + return seastar::format(R"({{"embedding":[{}],"limit":{}}})", fmt::join(embedding, ","), limit); +} + +auto read_ann_json(rjson::value const& json, schema_ptr const& schema) -> std::expected { + if (!json.HasMember("primary_keys")) { + vslogger.error("Vector Store returned invalid JSON: missing 'primary_keys'"); + return std::unexpected{service_reply_format_error{}}; + } + auto const& keys_json = json["primary_keys"]; + if (!keys_json.IsObject()) { + vslogger.error("Vector Store returned invalid JSON: 'primary_keys' is not an object"); + return std::unexpected{service_reply_format_error{}}; + } + + if (!json.HasMember("distances")) { + vslogger.error("Vector Store returned invalid JSON: missing 'distances'"); + return std::unexpected{service_reply_format_error{}}; + } + auto const& distances_json = json["distances"]; + if (!distances_json.IsArray()) { + vslogger.error("Vector Store returned invalid JSON: 'distances' is not an array"); + return std::unexpected{service_reply_format_error{}}; + } + auto const& distances_arr = json["distances"].GetArray(); + + auto size = distances_arr.Size(); + auto keys = primary_keys{}; + for (auto idx = 0U; idx < size; ++idx) { + auto pk = pk_from_json(keys_json, idx, schema); + if (!pk) { + return std::unexpected{pk.error()}; + } + auto ck = ck_from_json(keys_json, idx, schema); + if (!ck) { + return std::unexpected{ck.error()}; + } + keys.push_back(primary_key{dht::decorate_key(*schema, *pk), *ck}); + } + return std::move(keys); +} + +} // namespace + +namespace service { + +struct vector_store_client::impl { + lw_shared_ptr current_client; + std::vector> old_clients; + host_name host; + port_number port{}; + inet_address addr; + time_point last_dns_refresh; + gate tasks_gate; + condition_variable refresh_cv; + condition_variable refresh_client_cv; + abort_source abort_refresh; + milliseconds dns_refresh_interval = DNS_REFRESH_INTERVAL; + milliseconds wait_for_client_timeout = WAIT_FOR_CLIENT_TIMEOUT; + unsigned http_request_retries = HTTP_REQUEST_RETRIES; + std::function>(sstring const&)> dns_resolver; + sequential_producer> client_producer; + + impl(host_name host_, port_number port_) + : host(std::move(host_)) + , port(port_) + , dns_resolver([](auto const& host) -> future> { + auto addr = co_await coroutine::as_future(net::dns::resolve_name(host)); + if (addr.failed()) { + auto err = addr.get_exception(); + if (try_catch(err) != nullptr) { + co_return std::nullopt; + } + co_await coroutine::return_exception_ptr(std::move(err)); + } + co_return co_await std::move(addr); + }) + , client_producer([&]() -> future> { + trigger_dns_refresh(); + co_await wait_for_signal(refresh_client_cv, lowres_clock::now() + wait_for_client_timeout); + co_return current_client; + }) { + } + + /// Refresh the http client with a new address resolved from the DNS name. + /// If the DNS resolution fails, the current client is set to nullptr. + /// If the address is the same as the current one, do nothing. + /// Old clients are saved for later cleanup in a specific task. + auto refresh_addr() -> future<> { + auto new_addr = co_await dns_resolver(host); + if (!new_addr) { + current_client = nullptr; + co_return; + } + + // Check if the new address is the same as the current one + if (current_client && *new_addr == addr) { + co_return; + } + + addr = *new_addr; + old_clients.emplace_back(current_client); + current_client = make_lw_shared(socket_address(addr, port)); + } + + /// A task for refreshing the vector store http client. + auto refresh_addr_task() -> future<> { + for (;;) { + auto exception_occured = false; + try { + if (abort_refresh.abort_requested()) { + break; + } + + // Do not refresh the service address too often + auto now = lowres_clock::now(); + auto current_duration = now - last_dns_refresh; + if (current_duration > dns_refresh_interval) { + last_dns_refresh = now; + co_await refresh_addr(); + } else { + // Wait till the end of the refreshing interval + if (co_await wait_for_timeout(dns_refresh_interval - current_duration, abort_refresh)) { + continue; + } + // If the wait was aborted, we stop refreshing + break; + } + + if (abort_refresh.abort_requested()) { + break; + } + + // new client is available + refresh_client_cv.broadcast(); + + co_await cleanup_old_clients(); + + co_await refresh_cv.when(); + } catch (const std::exception& e) { + vslogger.error("Vector Store Client refresh task failed: {}", e.what()); + exception_occured = true; + } catch (...) { + vslogger.error("Vector Store Client refresh task failed with unknown exception"); + exception_occured = true; + } + if (exception_occured) { + // If an exception occurred, we wait for the next signal to refresh the address + co_await wait_for_timeout(EXCEPTION_OCCURED_WAIT, abort_refresh); + } + } + + co_await cleanup_old_clients(); + co_await cleanup_current_client(); + } + + /// Request a DNS refresh in the specific task. + void trigger_dns_refresh() { + refresh_cv.signal(); + } + + /// Cleanup current client + auto cleanup_current_client() -> future<> { + if (current_client) { + co_await current_client->close(); + } + current_client = nullptr; + } + + /// Cleanup old clients that are no longer used. + auto cleanup_old_clients() -> future<> { + // iterate over old clients and close them. There is a co_await in the loop + // so we need to use [] accessor and copying clients to avoid dangling references of iterators. + // NOLINTNEXTLINE(modernize-loop-convert) + for (auto it = 0U; it < old_clients.size(); ++it) { + auto& client = old_clients[it]; + if (client && client.owned()) { + auto client_cloned = client; + co_await client_cloned->close(); + client_cloned = nullptr; + } + } + std::erase_if(old_clients, [](auto const& client) { + return !client; + }); + } + + struct get_client_response { + lw_shared_ptr client; ///< The http client. + host_name host; ///< The host name for the vector-store service. + }; + + using get_client_error = std::variant; + + /// Get the current http client or wait for a new one to be available. + auto get_client(abort_source& as) -> future> { + if (current_client) { + co_return get_client_response{.client = current_client, .host = host}; + } + + auto current_client = co_await coroutine::as_future(client_producer(as)); + + if (current_client.failed()) { + auto err = current_client.get_exception(); + if (as.abort_requested()) { + co_return std::unexpected{aborted{}}; + } + co_await coroutine::return_exception_ptr(std::move(err)); + } + auto client = co_await std::move(current_client); + if (!client) { + co_return std::unexpected{addr_unavailable{}}; + } + co_return get_client_response{.client = client, .host = host}; + } + + struct make_request_response { + http::reply::status_type status; ///< The HTTP status of the response. + std::vector> content; ///< The content of the response. + }; + + using make_request_error = std::variant; + + auto make_request(operation_type method, http_path path, std::optional content, abort_source& as) + -> future> { + auto resp = make_request_response{.status = http::reply::status_type::ok, .content = std::vector>()}; + + for (auto retries = 0; retries < HTTP_REQUEST_RETRIES; ++retries) { + auto client_host = co_await get_client(as); + if (!client_host) { + co_return std::unexpected{std::visit( + [](auto&& err) { + return make_request_error{err}; + }, + client_host.error())}; + } + auto [client, host] = *std::move(client_host); + + auto req = http::request::make(method, host, path); + if (content) { + req.write_body("json", *content); + } + + auto result = co_await coroutine::as_future(client->make_request( + std::move(req), + [&resp](http::reply const& reply, input_stream body) -> future<> { + resp.status = reply._status; + resp.content = co_await util::read_entire_stream(body); + }, + std::nullopt, &as)); + if (result.failed()) { + auto err = result.get_exception(); + if (as.abort_requested()) { + co_return std::unexpected{aborted{}}; + } + if (try_catch(err) == nullptr) { + co_await coroutine::return_exception_ptr(std::move(err)); + } + // std::system_error means that the server is unavailable, so we retry + } else { + co_return resp; + } + + trigger_dns_refresh(); + } + + co_return std::unexpected{service_unavailable{}}; + } +}; + +vector_store_client::vector_store_client(config const& cfg) { + auto config_uri = cfg.vector_store_uri(); + if (config_uri.empty()) { + vslogger.info("Vector Store service URI is not configured."); + return; + } + + auto parsed_uri = parse_service_uri(config_uri); + if (!parsed_uri) { + throw configuration_exception(format("Invalid Vector Store service URI: {}", config_uri)); + } + + auto [host, port] = *parsed_uri; + _impl = std::make_unique(std::move(host), port); + vslogger.info("Vector Store service uri = {}:{}.", _impl->host, _impl->port); +} + +vector_store_client::~vector_store_client() = default; + +void vector_store_client::start_background_tasks() { + if (is_disabled()) { + return; + } + + /// start the background task to refresh the service address + (void)try_with_gate(_impl->tasks_gate, [this] { + return _impl->refresh_addr_task(); + }).handle_exception([](std::exception_ptr eptr) { + on_internal_error_noexcept(vslogger, format("The Vector Store Client refresh task failed: {}", eptr)); + }); +} + +auto vector_store_client::stop() -> future<> { + if (is_disabled()) { + co_return; + } + + _impl->abort_refresh.request_abort(); + _impl->refresh_cv.signal(); + co_await _impl->tasks_gate.close(); +} + +auto vector_store_client::host() const -> std::expected { + if (is_disabled()) { + return std::unexpected{disabled{}}; + } + return {_impl->host}; +} + +auto vector_store_client::port() const -> std::expected { + if (is_disabled()) { + return std::unexpected{disabled{}}; + } + return {_impl->port}; +} + +auto vector_store_client::ann(keyspace_name keyspace, index_name name, schema_ptr schema, embedding embedding, limit limit, abort_source& as) + -> future> { + if (is_disabled()) { + vslogger.error("Disabled Vector Store while calling ann"); + co_return std::unexpected{disabled{}}; + } + + auto path = format("/api/v1/indexes/{}/{}/ann", keyspace, name); + auto content = write_ann_json(std::move(embedding), limit); + + auto resp = co_await _impl->make_request(operation_type::POST, std::move(path), std::move(content), as); + if (!resp) { + co_return std::unexpected{std::visit( + [](auto&& err) { + return ann_error{err}; + }, + resp.error())}; + } + + if (resp->status != status_type::ok) { + vslogger.error("Vector Store returned error: HTTP status {}: {}", resp->status, resp->content); + co_return std::unexpected{service_error{resp->status}}; + } + + try { + co_return read_ann_json(rjson::parse(std::move(resp->content)), schema); + } catch (const rjson::error& e) { + vslogger.error("Vector Store returned invalid JSON: {}", e.what()); + co_return std::unexpected{service_reply_format_error{}}; + } +} + +void vector_store_client_tester::set_dns_refresh_interval(vector_store_client& vsc, std::chrono::milliseconds interval) { + if (vsc.is_disabled()) { + on_internal_error(vslogger, "Cannot set dns_refresh_interval on a disabled vector store client"); + } + vsc._impl->dns_refresh_interval = interval; +} + +void vector_store_client_tester::set_wait_for_client_timeout(vector_store_client& vsc, std::chrono::milliseconds timeout) { + if (vsc.is_disabled()) { + on_internal_error(vslogger, "Cannot set wait_for_client_timeout on a disabled vector store client"); + } + vsc._impl->wait_for_client_timeout = timeout; +} + +void vector_store_client_tester::set_http_request_retries(vector_store_client& vsc, unsigned retries) { + if (vsc.is_disabled()) { + on_internal_error(vslogger, "Cannot set http_request_retries on a disabled vector store client"); + } + vsc._impl->http_request_retries = retries; +} + +void vector_store_client_tester::set_dns_resolver(vector_store_client& vsc, std::function>(sstring const&)> resolver) { + if (vsc.is_disabled()) { + on_internal_error(vslogger, "Cannot set dns_resolver on a disabled vector store client"); + } + vsc._impl->dns_resolver = std::move(resolver); +} + +void vector_store_client_tester::trigger_dns_resolver(vector_store_client& vsc) { + if (vsc.is_disabled()) { + on_internal_error(vslogger, "Cannot trigger a dns resolver on a disabled vector store client"); + } + vsc._impl->trigger_dns_refresh(); +} + +auto vector_store_client_tester::resolve_hostname(vector_store_client& vsc, abort_source& as) -> future> { + if (vsc.is_disabled()) { + on_internal_error(vslogger, "Cannot check hostname resolving on a disabled vector store client"); + } + auto client_host = co_await vsc._impl->get_client(as); + if (!client_host) { + co_return std::nullopt; + } + co_return vsc._impl->addr; +} + +} // namespace service + diff --git a/service/vector_store_client.hh b/service/vector_store_client.hh new file mode 100644 index 0000000000..ac6d26da77 --- /dev/null +++ b/service/vector_store_client.hh @@ -0,0 +1,112 @@ +/* + * Copyright (C) 2025-present ScyllaDB + */ + +/* + * SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 + */ + +#pragma once + +#include "seastarx.hh" +#include +#include +#include +#include + +class schema; + +namespace cql3::statements { +class primary_key; +} + +namespace db { +class config; +} + +namespace seastar::net { +class inet_address; +} + +namespace service { + +/// A client with the vector-store service. +class vector_store_client final { + struct impl; + std::unique_ptr _impl; + +public: + using config = db::config; + using embedding = std::vector; + using host_name = sstring; + using index_name = sstring; + using keyspace_name = sstring; + using limit = std::size_t; + using port_number = std::uint16_t; + using primary_key = cql3::statements::primary_key; + using primary_keys = std::vector; + using schema_ptr = lw_shared_ptr; + using status_type = http::reply::status_type; + + /// The vector_store_client service is disabled. + struct disabled {}; + + /// The operation was aborted. + struct aborted {}; + + /// The vector-store addr is unavailable (not possible to get an addr from the dns service). + struct addr_unavailable {}; + + /// The vector-store service is unavailable. + struct service_unavailable {}; + + /// The error from the vector-store service. + struct service_error { + status_type status; ///< The HTTP status code from the vector-store service. + }; + + /// An unsupported reply format from the vector-store service. + struct service_reply_format_error {}; + + using ann_error = std::variant; + + explicit vector_store_client(config const& cfg); + ~vector_store_client(); + + /// Start background tasks. + void start_background_tasks(); + + /// Stop the service. + auto stop() -> future<>; + + /// Check if the vector_store_client is disabled. + auto is_disabled() const { + return !bool{_impl}; + } + + /// Get the current host name. + [[nodiscard]] auto host() const -> std::expected; + + /// Get the current port number. + [[nodiscard]] auto port() const -> std::expected; + + /// Request the vector store service for the primary keys of the nearest neighbors + auto ann(keyspace_name keyspace, index_name name, schema_ptr schema, embedding embedding, limit limit, abort_source& as) + -> future>; + +private: + friend struct vector_store_client_tester; +}; + +/// A tester for the vector_store_client, used for testing purposes. +struct vector_store_client_tester { + static void set_dns_refresh_interval(vector_store_client& vsc, std::chrono::milliseconds interval); + static void set_wait_for_client_timeout(vector_store_client& vsc, std::chrono::milliseconds timeout); + static void set_http_request_retries(vector_store_client& vsc, unsigned retries); + static void set_dns_resolver(vector_store_client& vsc, std::function>(sstring const&)> resolver); + static void trigger_dns_resolver(vector_store_client& vsc); + static auto resolve_hostname(vector_store_client& vsc, abort_source& as) -> future>; +}; + +} // namespace service + diff --git a/test/boost/CMakeLists.txt b/test/boost/CMakeLists.txt index 1fb8106a07..06bcb078f3 100644 --- a/test/boost/CMakeLists.txt +++ b/test/boost/CMakeLists.txt @@ -286,6 +286,8 @@ add_scylla_test(wrapping_interval_test KIND BOOST) add_scylla_test(address_map_test KIND SEASTAR) +add_scylla_test(vector_store_client_test + KIND SEASTAR) add_scylla_test(combined_tests KIND SEASTAR diff --git a/test/boost/vector_store_client_test.cc b/test/boost/vector_store_client_test.cc new file mode 100644 index 0000000000..ee15e4bbb7 --- /dev/null +++ b/test/boost/vector_store_client_test.cc @@ -0,0 +1,540 @@ +/* + * Copyright (C) 2025-present ScyllaDB + */ + +/* + * SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 + */ + +#include "service/vector_store_client.hh" +#include "db/config.hh" +#include "exceptions/exceptions.hh" +#include "cql3/statements/select_statement.hh" +#include "test/lib/cql_test_env.hh" +#include "test/lib/log.hh" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace { + +using namespace seastar; + +using vector_store_client = service::vector_store_client; +using vector_store_client_tester = service::vector_store_client_tester; +using config = vector_store_client::config; +using configuration_exception = exceptions::configuration_exception; +using inet_address = seastar::net::inet_address; +using function_handler = httpd::function_handler; +using http_server = httpd::http_server; +using http_server_tester = httpd::http_server_tester; +using milliseconds = std::chrono::milliseconds; +using operation_type = httpd::operation_type; +using port_number = vector_store_client::port_number; +using reply = http::reply; +using request = http::request; +using routes = httpd::routes; +using status_type = http::reply::status_type; +using url = httpd::url; + +constexpr auto const* LOCALHOST = "127.0.0.1"; + +/// Generate an ephemeral port number for listening on localhost. +/// After closing this socket, the port should be not listened on for a while. +/// This is not guaranteed to be a robust solution, but it should work for most tests. +auto generate_unavailable_localhost_port() -> port_number { + auto inaddr = net::inet_address(LOCALHOST); + auto server = listen(socket_address(inaddr, 0)); + auto port = server.local_address().port(); + server.abort_accept(); + return port; +} + +auto listen_on_ephemeral_port(std::unique_ptr server) -> future, socket_address>> { + auto inaddr = net::inet_address(LOCALHOST); + auto const addr = socket_address(inaddr, 0); + co_await server->listen(addr); + auto const& listeners = http_server_tester::listeners(*server); + BOOST_CHECK_EQUAL(listeners.size(), 1); + co_return std::make_tuple(std::move(server), listeners[0].local_address().port()); +} + +auto new_http_server(std::function set_routes) -> future, socket_address>> { + auto server = std::make_unique("test_vector_store_client"); + set_routes(server->_routes); + server->set_content_streaming(true); + co_return co_await listen_on_ephemeral_port(std::move(server)); +} + +auto repeat_until(milliseconds timeout, std::function()> func) -> future { + auto begin = lowres_clock::now(); + while (!co_await func()) { + if (lowres_clock::now() - begin > timeout) { + co_return false; + } + } + co_return true; +} + +auto print_addr(const inet_address& addr) -> sstring { + return format("{}", addr); +} + +} // namespace + +BOOST_AUTO_TEST_CASE(vector_store_client_test_ctor) { + { + auto cfg = config(); + auto vs = vector_store_client{cfg}; + BOOST_CHECK(vs.is_disabled()); + BOOST_CHECK(!vs.host()); + BOOST_CHECK(!vs.port()); + } + { + auto cfg = config(); + cfg.vector_store_uri.set("http://good.authority.com:6080"); + auto vs = vector_store_client{cfg}; + BOOST_CHECK(!vs.is_disabled()); + BOOST_CHECK_EQUAL(*vs.host(), "good.authority.com"); + BOOST_CHECK_EQUAL(*vs.port(), 6080); + } + { + auto cfg = config(); + cfg.vector_store_uri.set("http://bad,authority.com:6080"); + BOOST_CHECK_THROW(vector_store_client{cfg}, configuration_exception); + cfg.vector_store_uri.set("bad-schema://authority.com:6080"); + BOOST_CHECK_THROW(vector_store_client{cfg}, configuration_exception); + cfg.vector_store_uri.set("http://bad.port.com:a6080"); + BOOST_CHECK_THROW(vector_store_client{cfg}, configuration_exception); + cfg.vector_store_uri.set("http://bad.port.com:60806080"); + BOOST_CHECK_THROW(vector_store_client{cfg}, configuration_exception); + cfg.vector_store_uri.set("http://bad.format.com:60:80"); + BOOST_CHECK_THROW(vector_store_client{cfg}, configuration_exception); + cfg.vector_store_uri.set("http://authority.com:6080/bad/path"); + BOOST_CHECK_THROW(vector_store_client{cfg}, configuration_exception); + } +} + +/// Resolving of the hostname is started in start_background_tasks() +SEASTAR_TEST_CASE(vector_store_client_test_dns_started) { + auto cfg = config(); + cfg.vector_store_uri.set("http://good.authority.here:6080"); + + auto vs = vector_store_client{cfg}; + BOOST_CHECK(!vs.is_disabled()); + + vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(2000)); + vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(100)); + vector_store_client_tester::set_dns_resolver(vs, [](auto const& host) -> future> { + BOOST_CHECK_EQUAL(host, "good.authority.here"); + co_return inet_address("127.0.0.1"); + }); + + vs.start_background_tasks(); + + auto as = abort_source(); + auto addr = co_await vector_store_client_tester::resolve_hostname(vs, as); + BOOST_REQUIRE(addr); + BOOST_CHECK_EQUAL(print_addr(*addr), "127.0.0.1"); + + co_await vs.stop(); +} + +/// Unable to resolve the hostname +SEASTAR_TEST_CASE(vector_store_client_test_dns_resolve_failure) { + auto cfg = config(); + cfg.vector_store_uri.set("http://good.authority.here:6080"); + + + auto vs = vector_store_client{cfg}; + BOOST_CHECK(!vs.is_disabled()); + + vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(2000)); + vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(100)); + vector_store_client_tester::set_dns_resolver(vs, [](auto const& host) -> future> { + BOOST_CHECK_EQUAL(host, "good.authority.here"); + co_return std::nullopt; + }); + + vs.start_background_tasks(); + + auto as = abort_source(); + BOOST_CHECK(!co_await vector_store_client_tester::resolve_hostname(vs, as)); + + co_await vs.stop(); +} + +/// Resolving of the hostname is repeated after errors +SEASTAR_TEST_CASE(vector_store_client_test_dns_resolving_repeated) { + auto cfg = config(); + cfg.vector_store_uri.set("http://good.authority.here:6080"); + auto vs = vector_store_client{cfg}; + BOOST_CHECK(!vs.is_disabled()); + + vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(10)); + vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(20)); + auto count = 0; + vector_store_client_tester::set_dns_resolver(vs, [&count](auto const& host) -> future> { + BOOST_CHECK_EQUAL(host, "good.authority.here"); + count++; + if (count % 3 != 0) { + co_return std::nullopt; + } + co_return inet_address(format("127.0.0.{}", count)); + }); + + vs.start_background_tasks(); + + auto as = abort_source(); + BOOST_CHECK(co_await repeat_until(std::chrono::milliseconds(1000), [&vs, &as]() -> future { + co_return co_await vector_store_client_tester::resolve_hostname(vs, as); + })); + BOOST_CHECK_EQUAL(count, 3); + auto addr = co_await vector_store_client_tester::resolve_hostname(vs, as); + BOOST_REQUIRE(addr); + BOOST_CHECK_EQUAL(print_addr(*addr), "127.0.0.3"); + + vector_store_client_tester::trigger_dns_resolver(vs); + + BOOST_CHECK(co_await repeat_until(std::chrono::milliseconds(1000), [&vs, &as]() -> future { + co_return !co_await vector_store_client_tester::resolve_hostname(vs, as); + })); + + BOOST_CHECK(co_await repeat_until(std::chrono::milliseconds(1000), [&vs, &as]() -> future { + co_return co_await vector_store_client_tester::resolve_hostname(vs, as); + })); + BOOST_CHECK_EQUAL(count, 6); + addr = co_await vector_store_client_tester::resolve_hostname(vs, as); + BOOST_REQUIRE(addr); + BOOST_CHECK_EQUAL(print_addr(*addr), "127.0.0.6"); + + co_await vs.stop(); +} + +/// Minimal interval between DNS refreshes is respected +SEASTAR_TEST_CASE(vector_store_client_test_dns_refresh_respects_interval) { + auto cfg = config(); + cfg.vector_store_uri.set("http://good.authority.here:6080"); + auto vs = vector_store_client{cfg}; + BOOST_CHECK(!vs.is_disabled()); + + vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(10)); + vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(100)); + auto count = 0; + vector_store_client_tester::set_dns_resolver(vs, [&count](auto const& host) -> future> { + BOOST_CHECK_EQUAL(host, "good.authority.here"); + count++; + co_return inet_address("127.0.0.1"); + }); + + vs.start_background_tasks(); + co_await sleep(std::chrono::milliseconds(20)); // wait for the first DNS refresh + + auto as = abort_source(); + auto addr = co_await vector_store_client_tester::resolve_hostname(vs, as); + BOOST_REQUIRE(addr); + BOOST_CHECK_EQUAL(print_addr(*addr), "127.0.0.1"); + BOOST_CHECK_EQUAL(count, 1); + count = 0; + vector_store_client_tester::trigger_dns_resolver(vs); + vector_store_client_tester::trigger_dns_resolver(vs); + vector_store_client_tester::trigger_dns_resolver(vs); + vector_store_client_tester::trigger_dns_resolver(vs); + vector_store_client_tester::trigger_dns_resolver(vs); + co_await sleep(std::chrono::milliseconds(100)); // wait for the next DNS refresh + + addr = co_await vector_store_client_tester::resolve_hostname(vs, as); + BOOST_REQUIRE(addr); + BOOST_CHECK_EQUAL(print_addr(*addr), "127.0.0.1"); + BOOST_CHECK_GE(count, 1); + BOOST_CHECK_LE(count, 2); + + co_await vs.stop(); +} + +/// DNS refresh could be aborted +SEASTAR_TEST_CASE(vector_store_client_test_dns_refresh_aborted) { + auto cfg = config(); + cfg.vector_store_uri.set("http://good.authority.here:6080"); + auto vs = vector_store_client{cfg}; + BOOST_CHECK(!vs.is_disabled()); + + vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(10)); + vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(100)); + vector_store_client_tester::set_dns_resolver(vs, [](auto const& host) -> future> { + BOOST_CHECK_EQUAL(host, "good.authority.here"); + co_await sleep(std::chrono::milliseconds(100)); + co_return inet_address("127.0.0.1"); + }); + + vs.start_background_tasks(); + + auto as = abort_source(); + auto timeout = timer([&as]() { + as.request_abort(); + }); + timeout.arm(std::chrono::milliseconds(10)); + auto addr = co_await vector_store_client_tester::resolve_hostname(vs, as); + BOOST_CHECK(!addr); + + co_await vs.stop(); +} + +SEASTAR_TEST_CASE(vector_store_client_ann_test_disabled) { + co_await do_with_cql_env([](cql_test_env& env) -> future<> { + co_await env.execute_cql(R"( + create table ks.vs ( + pk1 tinyint, pk2 tinyint, + ck1 tinyint, ck2 tinyint, + embedding vector, + primary key ((pk1, pk2), ck1, ck2)) + )"); + + auto schema = env.local_db().find_schema("ks", "vs"); + auto& vs = env.local_qp().vector_store_client(); + + auto as = abort_source(); + auto keys = co_await vs.ann("ks", "idx", schema, std::vector{0.1, 0.2, 0.3}, 2, as); + BOOST_REQUIRE(!keys); + BOOST_CHECK(std::holds_alternative(keys.error())); + }); +} + +SEASTAR_TEST_CASE(vector_store_client_test_ann_addr_unavailable) { + auto cfg = cql_test_config(); + cfg.db_config->vector_store_uri.set("http://bad.authority.here:6080"); + co_await do_with_cql_env( + [](cql_test_env& env) -> future<> { + co_await env.execute_cql(R"( + create table ks.vs ( + pk1 tinyint, pk2 tinyint, + ck1 tinyint, ck2 tinyint, + embedding vector, + primary key ((pk1, pk2), ck1, ck2)) + )"); + + auto schema = env.local_db().find_schema("ks", "vs"); + auto& vs = env.local_qp().vector_store_client(); + + vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(1000)); + vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(100)); + vector_store_client_tester::set_http_request_retries(vs, 3); + vector_store_client_tester::set_dns_resolver(vs, [](auto const& host) -> future> { + BOOST_CHECK_EQUAL(host, "bad.authority.here"); + co_return std::nullopt; + }); + + vs.start_background_tasks(); + + auto as = abort_source(); + auto keys = co_await vs.ann("ks", "idx", schema, std::vector{0.1, 0.2, 0.3}, 2, as); + BOOST_REQUIRE(!keys); + BOOST_CHECK(std::holds_alternative(keys.error())); + }, + cfg); +} + +SEASTAR_TEST_CASE(vector_store_client_test_ann_service_unavailable) { + auto cfg = cql_test_config(); + cfg.db_config->vector_store_uri.set(format("http://good.authority.here:{}", generate_unavailable_localhost_port())); + co_await do_with_cql_env( + [](cql_test_env& env) -> future<> { + co_await env.execute_cql(R"( + create table ks.vs ( + pk1 tinyint, pk2 tinyint, + ck1 tinyint, ck2 tinyint, + embedding vector, + primary key ((pk1, pk2), ck1, ck2)) + )"); + + auto schema = env.local_db().find_schema("ks", "vs"); + auto& vs = env.local_qp().vector_store_client(); + + vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(1000)); + vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(100)); + vector_store_client_tester::set_http_request_retries(vs, 3); + vector_store_client_tester::set_dns_resolver(vs, [](auto const& host) -> future> { + BOOST_CHECK_EQUAL(host, "good.authority.here"); + co_return inet_address("127.0.0.1"); + }); + + vs.start_background_tasks(); + + auto as = abort_source(); + auto keys = co_await vs.ann("ks", "idx", schema, std::vector{0.1, 0.2, 0.3}, 2, as); + BOOST_REQUIRE(!keys); + BOOST_CHECK(std::holds_alternative(keys.error())); + }, + cfg); +} + +SEASTAR_TEST_CASE(vector_store_client_test_ann_service_aborted) { + auto cfg = cql_test_config(); + cfg.db_config->vector_store_uri.set(format("http://good.authority.here:{}", generate_unavailable_localhost_port())); + co_await do_with_cql_env( + [](cql_test_env& env) -> future<> { + co_await env.execute_cql(R"( + create table ks.vs ( + pk1 tinyint, pk2 tinyint, + ck1 tinyint, ck2 tinyint, + embedding vector, + primary key ((pk1, pk2), ck1, ck2)) + )"); + + auto schema = env.local_db().find_schema("ks", "vs"); + auto& vs = env.local_qp().vector_store_client(); + + vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(10)); + vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(100)); + vector_store_client_tester::set_http_request_retries(vs, 3); + vector_store_client_tester::set_dns_resolver(vs, [](auto const& host) -> future> { + BOOST_CHECK_EQUAL(host, "good.authority.here"); + co_await sleep(std::chrono::milliseconds(100)); + co_return inet_address("127.0.0.1"); + }); + + vs.start_background_tasks(); + + auto as = abort_source(); + auto timeout = timer([&as]() { + as.request_abort(); + }); + timeout.arm(std::chrono::milliseconds(10)); + auto keys = co_await vs.ann("ks", "idx", schema, std::vector{0.1, 0.2, 0.3}, 2, as); + BOOST_REQUIRE(!keys); + BOOST_CHECK(std::holds_alternative(keys.error())); + }, + cfg); +} + + +SEASTAR_TEST_CASE(vector_store_client_test_ann_request) { + auto ann_replies = make_lw_shared>>(); + auto [server, addr] = co_await new_http_server([ann_replies](routes& r) { + auto ann = [ann_replies](std::unique_ptr req, std::unique_ptr rep) -> future> { + BOOST_REQUIRE(!ann_replies->empty()); + auto [req_exp, rep_inp] = ann_replies->front(); + auto const req_inp = co_await util::read_entire_stream_contiguous(*req->content_stream); + BOOST_CHECK_EQUAL(req_inp, req_exp); + ann_replies->pop(); + rep->set_status(status_type::ok); + rep->write_body("json", rep_inp); + co_return rep; + }; + r.add(operation_type::POST, url("/api/v1/indexes/ks/idx").remainder("ann"), new function_handler(ann, "json")); + }); + + auto cfg = cql_test_config(); + cfg.db_config->vector_store_uri.set(format("http://good.authority.here:{}", addr.port())); + co_await do_with_cql_env( + [&ann_replies](cql_test_env& env) -> future<> { + co_await env.execute_cql(R"( + create table ks.vs ( + pk1 tinyint, pk2 tinyint, + ck1 tinyint, ck2 tinyint, + embedding vector, + primary key ((pk1, pk2), ck1, ck2)) + )"); + + auto schema = env.local_db().find_schema("ks", "vs"); + auto& vs = env.local_qp().vector_store_client(); + + vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(1000)); + vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(100)); + vector_store_client_tester::set_http_request_retries(vs, 3); + vector_store_client_tester::set_dns_resolver(vs, [](auto const& host) -> future> { + BOOST_CHECK_EQUAL(host, "good.authority.here"); + co_return inet_address("127.0.0.1"); + }); + + vs.start_background_tasks(); + + // set the wrong idx (wrong endpoint) - service should return 404 + auto as = abort_source(); + auto keys = co_await vs.ann("ks", "idx2", schema, std::vector{0.1, 0.2, 0.3}, 2, as); + BOOST_REQUIRE(!keys); + auto* err = std::get_if(&keys.error()); + BOOST_CHECK(err != nullptr); + BOOST_CHECK_EQUAL(err->status, status_type::not_found); + + // missing primary_keys in the reply - service should return format error + ann_replies->emplace(std::make_tuple(R"({"embedding":[0.1,0.2,0.3],"limit":2})", + R"({"primary_keys1":{"pk1":[5,6],"pk2":[7,8],"ck1":[9,1],"ck2":[2,3]},"distances":[0.1,0.2]})")); + auto const now = lowres_clock::now(); + for (;;) { + keys = co_await vs.ann("ks", "idx", schema, std::vector{0.1, 0.2, 0.3}, 2, as); + BOOST_REQUIRE(!keys); + + // if the service is unavailable or 400, retry, seems http server is not ready yet + auto* const unavailable = std::get_if(&keys.error()); + auto* const service_error = std::get_if(&keys.error()); + if ((unavailable == nullptr && service_error == nullptr) || + (service_error != nullptr && service_error->status != status_type::bad_request)) { + constexpr auto MAX_WAIT = std::chrono::seconds(5); + BOOST_REQUIRE(lowres_clock::now() - now < MAX_WAIT); + break; + } + } + BOOST_CHECK(std::holds_alternative(keys.error())); + + // missing distances in the reply - service should return format error + ann_replies->emplace(std::make_tuple(R"({"embedding":[0.1,0.2,0.3],"limit":2})", + R"({"primary_keys":{"pk1":[5,6],"pk2":[7,8],"ck1":[9,1],"ck2":[2,3]},"distances1":[0.1,0.2]})")); + keys = co_await vs.ann("ks", "idx", schema, std::vector{0.1, 0.2, 0.3}, 2, as); + BOOST_REQUIRE(!keys); + BOOST_CHECK(std::holds_alternative(keys.error())); + + // missing pk1 key in the reply - service should return format error + ann_replies->emplace(std::make_tuple(R"({"embedding":[0.1,0.2,0.3],"limit":2})", + R"({"primary_keys":{"pk11":[5,6],"pk2":[7,8],"ck1":[9,1],"ck2":[2,3]},"distances":[0.1,0.2]})")); + keys = co_await vs.ann("ks", "idx", schema, std::vector{0.1, 0.2, 0.3}, 2, as); + BOOST_REQUIRE(!keys); + BOOST_CHECK(std::holds_alternative(keys.error())); + + // missing ck1 key in the reply - service should return format error + ann_replies->emplace(std::make_tuple(R"({"embedding":[0.1,0.2,0.3],"limit":2})", + R"({"primary_keys":{"pk1":[5,6],"pk2":[7,8],"ck11":[9,1],"ck2":[2,3]},"distances":[0.1,0.2]})")); + keys = co_await vs.ann("ks", "idx", schema, std::vector{0.1, 0.2, 0.3}, 2, as); + BOOST_REQUIRE(!keys); + BOOST_CHECK(std::holds_alternative(keys.error())); + + // wrong size of pk2 key in the reply - service should return format error + ann_replies->emplace(std::make_tuple(R"({"embedding":[0.1,0.2,0.3],"limit":2})", + R"({"primary_keys":{"pk1":[5,6],"pk2":[78],"ck1":[9,1],"ck2":[2,3]},"distances":[0.1,0.2]})")); + keys = co_await vs.ann("ks", "idx", schema, std::vector{0.1, 0.2, 0.3}, 2, as); + BOOST_REQUIRE(!keys); + BOOST_CHECK(std::holds_alternative(keys.error())); + + // wrong size of ck2 key in the reply - service should return format error + ann_replies->emplace(std::make_tuple(R"({"embedding":[0.1,0.2,0.3],"limit":2})", + R"({"primary_keys":{"pk1":[5,6],"pk2":[7,8],"ck1":[9,1],"ck2":[23]},"distances":[0.1,0.2]})")); + keys = co_await vs.ann("ks", "idx", schema, std::vector{0.1, 0.2, 0.3}, 2, as); + BOOST_REQUIRE(!keys); + BOOST_CHECK(std::holds_alternative(keys.error())); + + // correct reply - service should return keys + ann_replies->emplace(std::make_tuple(R"({"embedding":[0.1,0.2,0.3],"limit":2})", + R"({"primary_keys":{"pk1":[5,6],"pk2":[7,8],"ck1":[9,1],"ck2":[2,3]},"distances":[0.1,0.2]})")); + keys = co_await vs.ann("ks", "idx", schema, std::vector{0.1, 0.2, 0.3}, 2, as); + BOOST_REQUIRE(keys); + BOOST_REQUIRE_EQUAL(keys->size(), 2); + BOOST_CHECK_EQUAL(seastar::format("{}", keys->at(0).partition.key().explode()), "[05, 07]"); + BOOST_CHECK_EQUAL(seastar::format("{}", keys->at(0).clustering.explode()), "[09, 02]"); + BOOST_CHECK_EQUAL(seastar::format("{}", keys->at(1).partition.key().explode()), "[06, 08]"); + BOOST_CHECK_EQUAL(seastar::format("{}", keys->at(1).clustering.explode()), "[01, 03]"); + }, + cfg); + co_await server->stop(); +} + diff --git a/test/lib/cql_test_env.cc b/test/lib/cql_test_env.cc index bd4bccda23..b12a5caa51 100644 --- a/test/lib/cql_test_env.cc +++ b/test/lib/cql_test_env.cc @@ -170,6 +170,7 @@ private: sharded _gossip_address_map; sharded _fd_pinger; sharded _cdc; + sharded _vector_store_client; db::config* _db_config; service::raft_group0_client* _group0_client; @@ -704,7 +705,12 @@ private: std::chrono::duration_cast(cql3::prepared_statements_cache::entry_expiry)); auth_prep_cache_config.refresh = std::chrono::milliseconds(cfg->permissions_update_interval_in_ms()); - _qp.start(std::ref(_proxy), std::move(local_data_dict), std::ref(_mnotifier), qp_mcfg, std::ref(_cql_config), auth_prep_cache_config, std::ref(_lang_manager)).get(); + _vector_store_client.start(std::ref(*cfg)).get(); + auto stop_vector_store_client = defer_verbose_shutdown("vector store client", [this] { + _vector_store_client.stop().get(); + }); + + _qp.start(std::ref(_proxy), std::move(local_data_dict), std::ref(_mnotifier), std::ref(_vector_store_client), qp_mcfg, std::ref(_cql_config), auth_prep_cache_config, std::ref(_lang_manager)).get(); auto stop_qp = defer_verbose_shutdown("query processor", [this] { _qp.stop().get(); }); _elc_notif.start().get(); diff --git a/utils/sequential_producer.hh b/utils/sequential_producer.hh index a075781510..853129f490 100644 --- a/utils/sequential_producer.hh +++ b/utils/sequential_producer.hh @@ -21,6 +21,7 @@ template class sequential_producer { public: using factory_t = std::function()>; + using time_point = seastar::shared_future::time_point; private: factory_t _factory; @@ -32,11 +33,18 @@ class sequential_producer { clear(); } - seastar::future operator()() { + seastar::future operator()(time_point timeout = time_point::max()) { if (_churning.available()) { _churning = _factory(); } - return _churning.get_future(); + return _churning.get_future(timeout); + } + + seastar::future operator()(seastar::abort_source& as) { + if (_churning.available()) { + _churning = _factory(); + } + return _churning.get_future(as); } void clear() {