Merge 'vector_store_client: implement vector_store_client service' from Pawel Pery

The Vector Store service is an HTTP server which provides a vector search index and ANN (Approximate Nearest Neighbor) functionality. Vector Store retrieves metadata and data about indexes from Scylla using the CQL protocol and CDC functionality. Scylla requests an ANN search using the HTTP API.

Commits for the patch:
- implement initial `vector_store_client` service. It also adds a `vector_store_uri` parameter to Scylla.
- refactor sequential_producer as abortable
- implement ip addr retrieval from dns. The URI for Vector Store must contain a DNS name; this commit implements the ip addr refreshing functionality
- refactor primary_key as a top-level class. It is needed for the forward declaration of a primary_key
- implement ANN API. It implements a core ANN search request functionality, adds Vector Store HTTP API description in docs/protocols.md, and implements automatic boost tests with mocked http server for checking error conditions.

New feature, should not be backported.

Fixes: VECTOR-47
Fixes: VECTOR-45

-~-

Closes scylladb/scylladb#24331

* github.com:scylladb/scylladb:
  vector_store_client: implement ANN API
  cql3: refactor primary_key as a top-level class
  vector_store_client: implement ip addr retrieval from dns
  utils: refactor sequential_producer as abortable
  vector_store_client: implement initial vector_store_client service
This commit is contained in:
Piotr Dulikowski
2025-07-10 13:18:20 +02:00
17 changed files with 1337 additions and 16 deletions

View File

@@ -855,3 +855,10 @@ rf_rack_valid_keyspaces: false
# Maximum number of items in single BatchWriteItem command. Default is 100.
# Note: DynamoDB has a hard-coded limit of 25.
# alternator_max_items_in_batch_write: 100
#
# Vector Store options
#
# URI of the vector store, given with a DNS name. Only the http scheme is supported. The port number is mandatory.
# Default is empty, which means that the vector store is not used.
# vector_store_uri: http://vector-store.dns.name:{port}

View File

@@ -567,6 +567,7 @@ scylla_tests = set([
'test/boost/symmetric_key_test',
'test/boost/types_test',
'test/boost/utf8_test',
'test/boost/vector_store_client_test',
'test/boost/vint_serialization_test',
'test/boost/virtual_table_mutation_source_test',
'test/boost/wasm_alloc_test',
@@ -1219,6 +1220,7 @@ scylla_core = (['message/messaging_service.cc',
'node_ops/task_manager_module.cc',
'reader_concurrency_semaphore_group.cc',
'utils/disk_space_monitor.cc',
'service/vector_store_client.cc',
] + [Antlr3Grammar('cql3/Cql.g')] \
+ scylla_raft_core
)

View File

@@ -27,6 +27,7 @@
#include "cql3/untyped_result_set.hh"
#include "db/config.hh"
#include "data_dictionary/data_dictionary.hh"
#include "service/vector_store_client.hh"
#include "utils/hashers.hh"
#include "utils/error_injection.hh"
#include "service/migration_manager.hh"
@@ -68,11 +69,12 @@ static service::query_state query_state_for_internal_call() {
return {service::client_state::for_internal_calls(), empty_service_permit()};
}
query_processor::query_processor(service::storage_proxy& proxy, data_dictionary::database db, service::migration_notifier& mn, query_processor::memory_config mcfg, cql_config& cql_cfg, utils::loading_cache_config auth_prep_cache_cfg, lang::manager& langm)
query_processor::query_processor(service::storage_proxy& proxy, data_dictionary::database db, service::migration_notifier& mn, service::vector_store_client& vsc, query_processor::memory_config mcfg, cql_config& cql_cfg, utils::loading_cache_config auth_prep_cache_cfg, lang::manager& langm)
: _migration_subscriber{std::make_unique<migration_subscriber>(this)}
, _proxy(proxy)
, _db(db)
, _mnotifier(mn)
, _vector_store_client(vsc)
, _mcfg(mcfg)
, _cql_config(cql_cfg)
, _prepared_cache(prep_cache_log, _mcfg.prepared_statment_cache_size)

View File

@@ -28,6 +28,7 @@
#include "transport/messages/result_message.hh"
#include "service/client_state.hh"
#include "service/broadcast_tables/experimental/query_result.hh"
#include "service/vector_store_client.hh"
#include "utils/assert.hh"
#include "utils/observable.hh"
#include "service/raft/raft_group0_client.hh"
@@ -107,6 +108,7 @@ private:
service::storage_proxy& _proxy;
data_dictionary::database _db;
service::migration_notifier& _mnotifier;
service::vector_store_client& _vector_store_client;
memory_config _mcfg;
const cql_config& _cql_config;
@@ -146,7 +148,7 @@ public:
static std::unique_ptr<statements::raw::parsed_statement> parse_statement(const std::string_view& query, dialect d);
static std::vector<std::unique_ptr<statements::raw::parsed_statement>> parse_statements(std::string_view queries, dialect d);
query_processor(service::storage_proxy& proxy, data_dictionary::database db, service::migration_notifier& mn, memory_config mcfg, cql_config& cql_cfg, utils::loading_cache_config auth_prep_cache_cfg, lang::manager& langm);
query_processor(service::storage_proxy& proxy, data_dictionary::database db, service::migration_notifier& mn, service::vector_store_client& vsc, memory_config mcfg, cql_config& cql_cfg, utils::loading_cache_config auth_prep_cache_cfg, lang::manager& langm);
~query_processor();
@@ -176,6 +178,14 @@ public:
lang::manager& lang() { return _lang_manager; }
/// Read-only access to the vector store client this query processor references.
const service::vector_store_client& vector_store_client() const noexcept {
return _vector_store_client;
}
/// Mutable access to the vector store client this query processor references.
service::vector_store_client& vector_store_client() noexcept {
return _vector_store_client;
}
db::auth_version_t auth_version;
statements::prepared_statement::checked_weak_ptr get_prepared(const std::optional<auth::authenticated_user>& user, const prepared_cache_key_type& key) {

View File

@@ -1471,10 +1471,10 @@ indexed_table_select_statement::find_index_partition_ranges(query_processor& qp,
// Note: the partitions keys returned by this function are sorted
// in token order. See issue #3423.
future<coordinator_result<std::tuple<std::vector<indexed_table_select_statement::primary_key>, lw_shared_ptr<const service::pager::paging_state>>>>
future<coordinator_result<std::tuple<std::vector<primary_key>, lw_shared_ptr<const service::pager::paging_state>>>>
indexed_table_select_statement::find_index_clustering_rows(query_processor& qp, service::query_state& state, const query_options& options) const
{
using value_type = std::tuple<std::vector<indexed_table_select_statement::primary_key>, lw_shared_ptr<const service::pager::paging_state>>;
using value_type = std::tuple<std::vector<primary_key>, lw_shared_ptr<const service::pager::paging_state>>;
auto now = gc_clock::now();
auto timeout = db::timeout_clock::now() + get_timeout(state.get_client_state(), options);
const uint64_t limit = get_inner_loop_limit(get_limit(options, _limit), _selection->is_aggregate());

View File

@@ -43,6 +43,13 @@ namespace restrictions {
namespace statements {
/// Encapsulates a partition key and clustering key prefix as a primary key.
///
/// Declared at namespace scope (cql3::statements) so that other headers can
/// forward-declare it.
struct primary_key {
dht::decorated_key partition;
clustering_key_prefix clustering;
};
/**
* Encapsulates a completely parsed SELECT query, including the target
* column family, expression, result count, and ordering clause.
@@ -134,12 +141,6 @@ public:
const query_options& options, gc_clock::time_point now, int32_t page_size, bool aggregate, bool nonpaged_filtering, uint64_t limit,
std::optional<service::cas_shard> cas_shard) const;
struct primary_key {
dht::decorated_key partition;
clustering_key_prefix clustering;
};
future<shared_ptr<cql_transport::messages::result_message>> process_results(foreign_ptr<lw_shared_ptr<query::result>> results,
lw_shared_ptr<query::read_command> cmd, const query_options& options, gc_clock::time_point now) const;

View File

@@ -1361,6 +1361,7 @@ db::config::config(std::shared_ptr<db::extensions> exts)
// alternator_max_items_in_batch_write matches DynamoDB behaviour of size limit, but with different value - for DynamoDB it's 25
// (see DynamoDB's documentation for BatchWriteItem command)
, alternator_max_items_in_batch_write(this, "alternator_max_items_in_batch_write", value_status::Used, 100, "Maximum amount of items in single BatchItemWrite call.")
, vector_store_uri(this, "vector_store_uri", value_status::Used, "", "The URI of the vector store to use for vector search. If not set, vector search is disabled.")
, abort_on_ebadf(this, "abort_on_ebadf", value_status::Used, true, "Abort the server on incorrect file descriptor access. Throws exception when disabled.")
, redis_port(this, "redis_port", value_status::Used, 0, "Port on which the REDIS transport listens for clients.")
, redis_ssl_port(this, "redis_ssl_port", value_status::Used, 0, "Port on which the REDIS TLS native transport listens for clients.")

View File

@@ -488,6 +488,8 @@ public:
named_value<sstring> alternator_describe_endpoints;
named_value<uint32_t> alternator_max_items_in_batch_write;
named_value<sstring> vector_store_uri;
named_value<bool> abort_on_ebadf;
named_value<uint16_t> redis_port;

View File

@@ -276,3 +276,48 @@ to Scylla's REST API port over the loopback (localhost) interface.
The port on which scylla-jmx listens is by default port 7199. This port,
and the listen address, can be overridden with the `-jp` and `-ja` options
(respectively) of the `scylla-jmx` script.
## Vector Store
The Vector Search functionality within Scylla is provided by an external
service [vector-store](https://github.com/scylla/vector-store). That service is
responsible for creating and managing vector indexes built from data retrieved
from Scylla using the CQL protocol and CDC functionality from a table and a
custom index created in Scylla using `CREATE INDEX {keyspace}.{index} ON
{table}({embedding}) USING 'vector_index'`. Scylla does a vector search by
delegating the search to the vector-store service using the HTTP API protocol.
Scylla is the HTTP client of the vector-store service.
The supported vector-store HTTP API:
### `POST /api/v1/indexes/{keyspace}/{index}/ann`
This endpoint is for an ANN (Approximate Nearest Neighbor) search.
Parameters:
- `keyspace`: The keyspace name of the index to search.
- `index`: The index name to search.
Request Body:
```json
{
"embedding": [0.1, 0.2, 0.3, ...], // The vector to search for.
"limit": 10 // The number of nearest neighbors to return.
}
```
Responses:
- 200 OK: Returns the nearest neighbors found.
Response Body (Structure of Arrays):
```json
{
"distances": [0.1234, 0.5678, ...], // The distances to the nearest neighbors up to the limit provided by a request.
"primary_keys": {
"pk1_column_name": ["value1", "value2", ...], // The primary key values of the pk1_column_name with same size as "distances".
"ck1_column_name": ["value1", "value2", ...], // The primary key values of the ck1_column_name with same size as "distances".
}
}
```

14
main.cc
View File

@@ -40,6 +40,7 @@
#include "service/migration_manager.hh"
#include "service/tablet_allocator.hh"
#include "service/load_meter.hh"
#include "service/vector_store_client.hh"
#include "service/view_update_backlog_broker.hh"
#include "service/qos/service_level_controller.hh"
#include "streaming/stream_session.hh"
@@ -740,6 +741,7 @@ sharded<locator::shared_token_metadata> token_metadata;
sharded<service::mapreduce_service> mapreduce_service;
sharded<gms::gossiper> gossiper;
sharded<locator::snitch_ptr> snitch;
sharded<service::vector_store_client> vector_store_client;
// This worker wasn't designed to be used from multiple threads.
// If you are attempting to do that, make sure you know what you are doing.
@@ -779,7 +781,8 @@ sharded<locator::shared_token_metadata> token_metadata;
return seastar::async([&app, cfg, ext, &disk_space_monitor_shard0, &cm, &sstm, &db, &qp, &bm, &proxy, &mapreduce_service, &mm, &mm_notifier, &ctx, &opts, &dirs,
&prometheus_server, &cf_cache_hitrate_calculator, &load_meter, &feature_service, &gossiper, &snitch,
&token_metadata, &erm_factory, &snapshot_ctl, &messaging, &sst_dir_semaphore, &raft_gr, &service_memory_limiter,
&repair, &sst_loader, &ss, &lifecycle_notifier, &stream_manager, &task_manager, &rpc_dict_training_worker] {
&repair, &sst_loader, &ss, &lifecycle_notifier, &stream_manager, &task_manager, &rpc_dict_training_worker,
&vector_store_client] {
try {
if (opts.contains("relabel-config-file") && !opts["relabel-config-file"].as<sstring>().empty()) {
// calling update_relabel_config_from_file can cause an exception that would stop startup
@@ -1318,6 +1321,13 @@ sharded<locator::shared_token_metadata> token_metadata;
static sharded<cql3::cql_config> cql_config;
cql_config.start(std::ref(*cfg)).get();
checkpoint(stop_signal, "starting a vector store service");
vector_store_client.start(std::ref(*cfg)).get();
auto stop_vector_store_client = defer_verbose_shutdown("vector store client", [&vector_store_client] {
vector_store_client.stop().get();
});
vector_store_client.invoke_on_all(&service::vector_store_client::start_background_tasks).get();
checkpoint(stop_signal, "starting query processor");
cql3::query_processor::memory_config qp_mcfg = {memory::stats().total_memory() / 256, memory::stats().total_memory() / 2560};
debug::the_query_processor = &qp;
@@ -1329,7 +1339,7 @@ sharded<locator::shared_token_metadata> token_metadata;
std::chrono::duration_cast<std::chrono::milliseconds>(cql3::prepared_statements_cache::entry_expiry));
auth_prep_cache_config.refresh = std::chrono::milliseconds(cfg->permissions_update_interval_in_ms());
qp.start(std::ref(proxy), std::move(local_data_dict), std::ref(mm_notifier), qp_mcfg, std::ref(cql_config), std::move(auth_prep_cache_config), std::ref(langman)).get();
qp.start(std::ref(proxy), std::move(local_data_dict), std::ref(mm_notifier), std::ref(vector_store_client), qp_mcfg, std::ref(cql_config), std::move(auth_prep_cache_config), std::ref(langman)).get();
checkpoint(stop_signal, "starting lifecycle notifier");
lifecycle_notifier.start().get();

View File

@@ -33,7 +33,8 @@ target_sources(service
task_manager_module.cc
topology_coordinator.cc
topology_mutation.cc
topology_state_machine.cc)
topology_state_machine.cc
vector_store_client.cc)
target_include_directories(service
PUBLIC
${CMAKE_SOURCE_DIR})

View File

@@ -0,0 +1,572 @@
/*
* Copyright (C) 2025-present ScyllaDB
*/
/*
* SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
*/
#include "vector_store_client.hh"
#include "cql3/statements/select_statement.hh"
#include "cql3/type_json.hh"
#include "db/config.hh"
#include "exceptions/exceptions.hh"
#include "utils/sequential_producer.hh"
#include "dht/i_partitioner.hh"
#include "keys.hh"
#include "utils/rjson.hh"
#include "schema/schema.hh"
#include <charconv>
#include <exception>
#include <fmt/ranges.h>
#include <regex>
#include <seastar/coroutine/as_future.hh>
#include <seastar/coroutine/exception.hh>
#include <seastar/http/client.hh>
#include <seastar/http/request.hh>
#include <seastar/net/dns.hh>
#include <seastar/net/inet_address.hh>
#include <seastar/net/socket_defs.hh>
#include <seastar/util/short_streams.hh>
namespace {
// Short, file-local aliases for the types used by the helpers below.
using ann_error = service::vector_store_client::ann_error;
using configuration_exception = exceptions::configuration_exception;
using duration = lowres_clock::duration;
using embedding = service::vector_store_client::embedding;
using limit = service::vector_store_client::limit;
using host_name = service::vector_store_client::host_name;
using http_client = http::experimental::client;
using http_path = sstring;
using inet_address = seastar::net::inet_address;
using json_content = sstring;
using milliseconds = std::chrono::milliseconds;
using operation_type = httpd::operation_type;
using port_number = service::vector_store_client::port_number;
using primary_key = service::vector_store_client::primary_key;
using primary_keys = service::vector_store_client::primary_keys;
using service_reply_format_error = service::vector_store_client::service_reply_format_error;
using time_point = lowres_clock::time_point;
/// Delay before the refresh task retries after an exception occurred.
constexpr auto EXCEPTION_OCCURED_WAIT = std::chrono::seconds(5);
/// Default minimum interval between two DNS name refreshes.
constexpr auto DNS_REFRESH_INTERVAL = std::chrono::seconds(5);
/// Default timeout for waiting for a new client to be available.
constexpr auto WAIT_FOR_CLIENT_TIMEOUT = std::chrono::seconds(5);
/// Default number of attempts for an HTTP request before giving up.
constexpr auto HTTP_REQUEST_RETRIES = 3;
// Logger shared by all the code in this file.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
logging::logger vslogger("vector_store_client");
/// Parse a decimal port number.
///
/// \param port_txt text expected to consist solely of decimal digits.
/// \return the parsed port, or std::nullopt when the text is empty, contains
///         any character that is not part of the number, or the value does
///         not fit into port_number.
auto parse_port(std::string const& port_txt) -> std::optional<port_number> {
    auto port = port_number{};
    // Use data()/size() instead of dereferencing end(), which is undefined behaviour.
    auto const* const first = port_txt.data();
    auto const* const last = first + port_txt.size();
    auto const [ptr, ec] = std::from_chars(first, last, port);
    // The whole input must be consumed; comparing against `last` (rather than
    // looking for a NUL terminator) also rejects text with an embedded '\0'.
    if (ec != std::errc{} || ptr != last) {
        return std::nullopt;
    }
    return port;
}
/// Split a vector store service URI into its host name and port.
///
/// The URI has to match `http://<host>:<port>` exactly, where the host is
/// built from lowercase letters, digits, '.', '_' and '-'.
///
/// \return {host, port} on success, an empty optional when the URI does not
///         match or the port cannot be parsed.
auto parse_service_uri(std::string_view uri) -> std::optional<std::tuple<host_name, port_number>> {
    auto const pattern = std::regex(R"(^http:\/\/([a-z0-9._-]+):([0-9]+)$)");
    // std::regex_match with std::smatch needs an owning std::string.
    auto const text = std::string(uri);
    auto groups = std::smatch{};
    auto const matched = std::regex_match(text, groups, pattern);
    if (matched && groups.size() == 3) {
        auto host = groups[1].str();
        auto port = parse_port(groups[2].str());
        if (port) {
            return {{host, *port}};
        }
    }
    return {};
}
/// Wait for a timeout or an abort signal.
///
/// \return true when the whole timeout elapsed, false when the sleep failed
///         and the abort source has been triggered. Any other sleep failure
///         is propagated as an exception.
auto wait_for_timeout(duration timeout, abort_source& as) -> future<bool> {
    auto slept = co_await coroutine::as_future(sleep_abortable(timeout, as));
    if (!slept.failed()) {
        co_return true;
    }
    auto failure = slept.get_exception();
    if (!as.abort_requested()) {
        co_await coroutine::return_exception_ptr(std::move(failure));
    }
    co_return false;
}
/// Wait until the condition variable is signaled or the deadline passes.
///
/// \return true when the wait completed successfully, false when it failed
///         with condition_variable_timed_out. Any other failure is
///         propagated as an exception.
auto wait_for_signal(condition_variable& cv, time_point timeout) -> future<bool> {
    auto waited = co_await coroutine::as_future(cv.wait(timeout));
    if (!waited.failed()) {
        co_return true;
    }
    auto failure = waited.get_exception();
    if (try_catch<condition_variable_timed_out>(failure) == nullptr) {
        co_await coroutine::return_exception_ptr(std::move(failure));
    }
    co_return false;
}
/// Extract the serialized value of one key column for the idx-th result.
///
/// `item` is the "primary_keys" object of the reply, which maps a column
/// name to an array of values (one entry per result).
///
/// \return the value converted with from_json_object() for the column type,
///         or service_reply_format_error when the column is missing, is not
///         an array, or has no element at position idx.
auto get_key_column_value(const rjson::value& item, std::size_t idx, const column_definition& column) -> std::expected<bytes, ann_error> {
    auto const invalid_reply = [] {
        return std::unexpected{service_reply_format_error{}};
    };
    auto const& name = column.name_as_text();

    auto const* values = rjson::find(item, name);
    if (values == nullptr) {
        vslogger.error("Vector Store returned invalid JSON: missing key column '{}'", name);
        return invalid_reply();
    }
    if (!values->IsArray()) {
        vslogger.error("Vector Store returned invalid JSON: key column '{}' is not an array", name);
        return invalid_reply();
    }

    auto const& array = values->GetArray();
    if (idx >= array.Size()) {
        vslogger.error("Vector Store returned invalid JSON: key column '{}' array too small", name);
        return invalid_reply();
    }
    return from_json_object(*column.type, array[idx]);
}
/// Build the partition key of the idx-th result from the "primary_keys" object.
///
/// Every partition key column of the schema is looked up in `item`; the
/// first lookup error is returned unchanged.
auto pk_from_json(rjson::value const& item, std::size_t idx, schema_ptr const& schema) -> std::expected<partition_key, ann_error> {
    auto components = std::vector<bytes>{};
    for (const column_definition& column : schema->partition_key_columns()) {
        auto value = get_key_column_value(item, idx, column);
        if (value.has_value()) {
            components.emplace_back(*value);
            continue;
        }
        return std::unexpected{value.error()};
    }
    return partition_key::from_exploded(components);
}
/// Build the clustering key of the idx-th result from the "primary_keys" object.
///
/// Returns an empty clustering key prefix when the schema has no clustering
/// columns. Otherwise every clustering column is looked up in `item` and the
/// first lookup error is returned unchanged.
auto ck_from_json(rjson::value const& item, std::size_t idx, schema_ptr const& schema) -> std::expected<clustering_key_prefix, ann_error> {
    if (schema->clustering_key_size() != 0) {
        auto components = std::vector<bytes>{};
        for (const column_definition& column : schema->clustering_key_columns()) {
            auto value = get_key_column_value(item, idx, column);
            if (value.has_value()) {
                components.emplace_back(*value);
                continue;
            }
            return std::unexpected{value.error()};
        }
        return clustering_key_prefix::from_exploded(components);
    }
    return clustering_key_prefix::make_empty();
}
/// Serialize the body of an ANN request: a JSON object with the "embedding"
/// array (elements joined with ',') and the "limit" of results to return.
auto write_ann_json(embedding embedding, limit limit) -> json_content {
return seastar::format(R"({{"embedding":[{}],"limit":{}}})", fmt::join(embedding, ","), limit);
}
/// Convert the JSON reply of an ANN request into a list of primary keys.
///
/// The reply must be an object with a "primary_keys" object (column name ->
/// array of values) and a "distances" array. The number of returned keys is
/// the size of "distances"; for every index the partition and clustering key
/// are assembled from the per-column arrays.
///
/// \return the primary keys in reply order, or service_reply_format_error
///         (after logging the reason) when the reply is malformed.
auto read_ann_json(rjson::value const& json, schema_ptr const& schema) -> std::expected<primary_keys, ann_error> {
    // Member lookups below are only meaningful on an object, so reject
    // anything else up front with a clear message.
    if (!json.IsObject()) {
        vslogger.error("Vector Store returned invalid JSON: reply is not an object");
        return std::unexpected{service_reply_format_error{}};
    }
    if (!json.HasMember("primary_keys")) {
        vslogger.error("Vector Store returned invalid JSON: missing 'primary_keys'");
        return std::unexpected{service_reply_format_error{}};
    }
    auto const& keys_json = json["primary_keys"];
    if (!keys_json.IsObject()) {
        vslogger.error("Vector Store returned invalid JSON: 'primary_keys' is not an object");
        return std::unexpected{service_reply_format_error{}};
    }
    if (!json.HasMember("distances")) {
        vslogger.error("Vector Store returned invalid JSON: missing 'distances'");
        return std::unexpected{service_reply_format_error{}};
    }
    auto const& distances_json = json["distances"];
    if (!distances_json.IsArray()) {
        vslogger.error("Vector Store returned invalid JSON: 'distances' is not an array");
        return std::unexpected{service_reply_format_error{}};
    }
    // Reuse the already looked-up member instead of searching for it again.
    auto const size = distances_json.GetArray().Size();
    auto keys = primary_keys{};
    keys.reserve(size);
    for (auto idx = 0U; idx < size; ++idx) {
        auto pk = pk_from_json(keys_json, idx, schema);
        if (!pk) {
            return std::unexpected{pk.error()};
        }
        auto ck = ck_from_json(keys_json, idx, schema);
        if (!ck) {
            return std::unexpected{ck.error()};
        }
        keys.push_back(primary_key{dht::decorate_key(*schema, *pk), *ck});
    }
    return std::move(keys);
}
} // namespace
namespace service {
/// State and logic behind vector_store_client.
///
/// Keeps the http client for the currently resolved service address, a
/// background task (refresh_addr_task) that re-resolves the DNS name on
/// demand, and the request logic with retries.
struct vector_store_client::impl {
    /// Client for the most recently resolved address; nullptr when there is none.
    lw_shared_ptr<http_client> current_client;
    /// Replaced clients which are closed once nobody else holds a reference.
    std::vector<lw_shared_ptr<http_client>> old_clients;
    host_name host;
    port_number port{};
    /// Address the current client was created for.
    inet_address addr;
    time_point last_dns_refresh;
    gate tasks_gate;
    /// Signaled to ask the refresh task for a new DNS resolution.
    condition_variable refresh_cv;
    /// Broadcast by the refresh task after a resolution attempt finished.
    condition_variable refresh_client_cv;
    abort_source abort_refresh;
    milliseconds dns_refresh_interval = DNS_REFRESH_INTERVAL;
    milliseconds wait_for_client_timeout = WAIT_FOR_CLIENT_TIMEOUT;
    unsigned http_request_retries = HTTP_REQUEST_RETRIES;
    std::function<future<std::optional<inet_address>>(sstring const&)> dns_resolver;
    sequential_producer<lw_shared_ptr<http_client>> client_producer;

    impl(host_name host_, port_number port_)
        : host(std::move(host_))
        , port(port_)
        // Default resolver: a std::system_error from the DNS lookup is
        // reported as "no address"; any other failure is propagated.
        , dns_resolver([](auto const& host) -> future<std::optional<inet_address>> {
            auto addr = co_await coroutine::as_future(net::dns::resolve_name(host));
            if (addr.failed()) {
                auto err = addr.get_exception();
                if (try_catch<std::system_error>(err) != nullptr) {
                    co_return std::nullopt;
                }
                co_await coroutine::return_exception_ptr(std::move(err));
            }
            co_return co_await std::move(addr);
        })
        // Producer of a client: asks for a DNS refresh, waits (with a
        // timeout) for the refresh task and returns whatever client is
        // current afterwards - possibly nullptr.
        , client_producer([this]() -> future<lw_shared_ptr<http_client>> {
            trigger_dns_refresh();
            co_await wait_for_signal(refresh_client_cv, lowres_clock::now() + wait_for_client_timeout);
            co_return current_client;
        }) {
    }

    /// Hand the current client (if any) over to old_clients, so that it is
    /// closed by cleanup_old_clients() instead of being dropped unclosed.
    void retire_current_client() {
        if (current_client) {
            old_clients.emplace_back(std::move(current_client));
        }
        current_client = nullptr;
    }

    /// Refresh the http client with a new address resolved from the DNS name.
    /// If the DNS resolution fails, the current client is set to nullptr.
    /// If the address is the same as the current one, do nothing.
    /// Old clients are saved for later cleanup in a specific task.
    auto refresh_addr() -> future<> {
        auto new_addr = co_await dns_resolver(host);
        if (!new_addr) {
            retire_current_client();
            co_return;
        }

        // Check if the new address is the same as the current one
        if (current_client && *new_addr == addr) {
            co_return;
        }

        addr = *new_addr;
        retire_current_client();
        current_client = make_lw_shared<http_client>(socket_address(addr, port));
    }

    /// A task for refreshing the vector store http client.
    ///
    /// Each round resolves the address (at most once per
    /// dns_refresh_interval), wakes up everybody waiting for a client,
    /// closes unused old clients and then sleeps until a refresh is
    /// requested again. The loop ends when abort_refresh is triggered; all
    /// clients are cleaned up before the task finishes.
    auto refresh_addr_task() -> future<> {
        for (;;) {
            auto exception_occured = false;
            try {
                if (abort_refresh.abort_requested()) {
                    break;
                }

                // Do not refresh the service address too often
                auto now = lowres_clock::now();
                auto current_duration = now - last_dns_refresh;
                if (current_duration > dns_refresh_interval) {
                    last_dns_refresh = now;
                    co_await refresh_addr();
                } else {
                    // Wait till the end of the refreshing interval
                    if (co_await wait_for_timeout(dns_refresh_interval - current_duration, abort_refresh)) {
                        continue;
                    }
                    // If the wait was aborted, we stop refreshing
                    break;
                }

                if (abort_refresh.abort_requested()) {
                    break;
                }

                // new client is available
                refresh_client_cv.broadcast();

                co_await cleanup_old_clients();

                co_await refresh_cv.when();
            } catch (const std::exception& e) {
                vslogger.error("Vector Store Client refresh task failed: {}", e.what());
                exception_occured = true;
            } catch (...) {
                vslogger.error("Vector Store Client refresh task failed with unknown exception");
                exception_occured = true;
            }
            if (exception_occured) {
                // co_await is not allowed inside a catch block, so the
                // back-off before the next attempt happens here.
                co_await wait_for_timeout(EXCEPTION_OCCURED_WAIT, abort_refresh);
            }
        }
        co_await cleanup_old_clients();
        co_await cleanup_current_client();
    }

    /// Request a DNS refresh in the specific task.
    void trigger_dns_refresh() {
        refresh_cv.signal();
    }

    /// Cleanup current client
    auto cleanup_current_client() -> future<> {
        if (current_client) {
            co_await current_client->close();
        }
        current_client = nullptr;
    }

    /// Cleanup old clients that are no longer used.
    ///
    /// A client is closed only when old_clients holds the last reference to
    /// it; clients still used elsewhere stay in the list for a later pass.
    auto cleanup_old_clients() -> future<> {
        // There is a co_await in the loop, so use the [] accessor and do not
        // keep references or iterators into the vector across suspension.
        // NOLINTNEXTLINE(modernize-loop-convert)
        for (auto it = 0U; it < old_clients.size(); ++it) {
            auto& client = old_clients[it];
            if (client && client.owned()) {
                // Take the client out of the list before closing it: the
                // slot becomes nullptr, so it is erased below and the
                // client is never closed a second time.
                auto closing = std::move(client);
                old_clients[it] = nullptr;
                co_await closing->close();
            }
        }
        std::erase_if(old_clients, [](auto const& client) {
            return !client;
        });
    }

    struct get_client_response {
        lw_shared_ptr<http_client> client; ///< The http client.
        host_name host;                    ///< The host name for the vector-store service.
    };

    using get_client_error = std::variant<aborted, addr_unavailable>;

    /// Get the current http client or wait for a new one to be available.
    ///
    /// \return the client together with the host name; `aborted` when the
    ///         wait failed and `as` has been triggered; `addr_unavailable`
    ///         when no client exists after waiting.
    auto get_client(abort_source& as) -> future<std::expected<get_client_response, get_client_error>> {
        if (current_client) {
            co_return get_client_response{.client = current_client, .host = host};
        }

        // Named differently from the `current_client` member on purpose.
        auto produced = co_await coroutine::as_future(client_producer(as));
        if (produced.failed()) {
            auto err = produced.get_exception();
            if (as.abort_requested()) {
                co_return std::unexpected{aborted{}};
            }
            co_await coroutine::return_exception_ptr(std::move(err));
        }
        auto client = co_await std::move(produced);
        if (!client) {
            co_return std::unexpected{addr_unavailable{}};
        }
        co_return get_client_response{.client = client, .host = host};
    }

    struct make_request_response {
        http::reply::status_type status;             ///< The HTTP status of the response.
        std::vector<temporary_buffer<char>> content; ///< The content of the response.
    };

    using make_request_error = std::variant<aborted, addr_unavailable, service_unavailable>;

    /// Send an HTTP request to the vector-store service.
    ///
    /// The request is attempted up to http_request_retries times. A failure
    /// with std::system_error is treated as "server unavailable": a DNS
    /// refresh is requested and the request is retried. Other failures are
    /// propagated as exceptions.
    ///
    /// \param content optional JSON body of the request.
    /// \return status and body of the reply; `aborted` / `addr_unavailable`
    ///         as reported by get_client(); `aborted` when the request
    ///         failed and `as` has been triggered; `service_unavailable`
    ///         when all attempts failed.
    auto make_request(operation_type method, http_path path, std::optional<json_content> content, abort_source& as)
            -> future<std::expected<make_request_response, make_request_error>> {
        auto resp = make_request_response{.status = http::reply::status_type::ok, .content = std::vector<temporary_buffer<char>>()};

        // Use the configurable member, not the HTTP_REQUEST_RETRIES default,
        // so that the value set through the tester takes effect.
        for (auto attempt = 0U; attempt < http_request_retries; ++attempt) {
            auto client_host = co_await get_client(as);
            if (!client_host) {
                co_return std::unexpected{std::visit(
                        [](auto&& err) {
                            return make_request_error{err};
                        },
                        client_host.error())};
            }

            auto [client, host] = *std::move(client_host);

            auto req = http::request::make(method, host, path);
            if (content) {
                req.write_body("json", *content);
            }

            auto result = co_await coroutine::as_future(client->make_request(
                    std::move(req),
                    [&resp](http::reply const& reply, input_stream<char> body) -> future<> {
                        resp.status = reply._status;
                        resp.content = co_await util::read_entire_stream(body);
                    },
                    std::nullopt, &as));
            if (result.failed()) {
                auto err = result.get_exception();
                if (as.abort_requested()) {
                    co_return std::unexpected{aborted{}};
                }
                if (try_catch<std::system_error>(err) == nullptr) {
                    co_await coroutine::return_exception_ptr(std::move(err));
                }
                // std::system_error means that the server is unavailable, so we retry
            } else {
                co_return resp;
            }

            trigger_dns_refresh();
        }

        co_return std::unexpected{service_unavailable{}};
    }
};
/// Create the client from the `vector_store_uri` configuration option.
///
/// An empty URI leaves the client without an implementation object. A URI
/// which cannot be parsed raises configuration_exception.
vector_store_client::vector_store_client(config const& cfg) {
    auto const uri = cfg.vector_store_uri();
    if (uri.empty()) {
        vslogger.info("Vector Store service URI is not configured.");
        return;
    }

    auto endpoint = parse_service_uri(uri);
    if (!endpoint.has_value()) {
        throw configuration_exception(format("Invalid Vector Store service URI: {}", uri));
    }

    auto& [host, port] = *endpoint;
    _impl = std::make_unique<impl>(std::move(host), port);
    vslogger.info("Vector Store service uri = {}:{}.", _impl->host, _impl->port);
}
// Defaulted here, in the translation unit that defines `impl`, which is held through std::unique_ptr<impl>.
vector_store_client::~vector_store_client() = default;
/// Start the background task which refreshes the service address.
///
/// Does nothing when the client is disabled. The task runs under
/// `tasks_gate`; a failure of the task is reported as an internal error.
void vector_store_client::start_background_tasks() {
    if (!is_disabled()) {
        auto refresh_task = try_with_gate(_impl->tasks_gate, [this] {
            return _impl->refresh_addr_task();
        });
        // The task is intentionally detached; stop() waits for it through the gate.
        (void)std::move(refresh_task).handle_exception([](std::exception_ptr eptr) {
            on_internal_error_noexcept(vslogger, format("The Vector Store Client refresh task failed: {}", eptr));
        });
    }
}
/// Stop the client: request an abort of the refresh task, wake it up and
/// wait until the tasks gate is closed. Does nothing when the client is
/// disabled.
auto vector_store_client::stop() -> future<> {
    if (!is_disabled()) {
        auto& state = *_impl;
        state.abort_refresh.request_abort();
        state.refresh_cv.signal();
        co_await state.tasks_gate.close();
    }
}
/// Host name of the vector-store service, or `disabled` when the client is disabled.
auto vector_store_client::host() const -> std::expected<host_name, disabled> {
    if (!is_disabled()) {
        return {_impl->host};
    }
    return std::unexpected{disabled{}};
}
/// Port number of the vector-store service, or `disabled` when the client is disabled.
auto vector_store_client::port() const -> std::expected<port_number, disabled> {
    if (!is_disabled()) {
        return {_impl->port};
    }
    return std::unexpected{disabled{}};
}
/// Run an ANN (Approximate Nearest Neighbor) search on the vector-store service.
///
/// Sends `POST /api/v1/indexes/{keyspace}/{name}/ann` with the embedding and
/// the limit as a JSON body and converts the reply into primary keys of
/// `schema`.
///
/// \return the primary keys from the reply, or one of:
///         - `disabled` when the client is disabled,
///         - the error reported by the request itself,
///         - `service_error` with the HTTP status when it is not 200 OK,
///         - `service_reply_format_error` when the reply is not valid JSON
///           or does not have the expected structure.
auto vector_store_client::ann(keyspace_name keyspace, index_name name, schema_ptr schema, embedding embedding, limit limit, abort_source& as)
        -> future<std::expected<primary_keys, ann_error>> {
    if (is_disabled()) {
        vslogger.error("Disabled Vector Store while calling ann");
        co_return std::unexpected{disabled{}};
    }

    auto request_path = format("/api/v1/indexes/{}/{}/ann", keyspace, name);
    auto request_body = write_ann_json(std::move(embedding), limit);
    auto reply = co_await _impl->make_request(operation_type::POST, std::move(request_path), std::move(request_body), as);
    if (!reply.has_value()) {
        // Widen the request error into the ann_error variant.
        auto const widen = [](auto&& err) {
            return ann_error{err};
        };
        co_return std::unexpected{std::visit(widen, reply.error())};
    }

    if (reply->status != status_type::ok) {
        vslogger.error("Vector Store returned error: HTTP status {}: {}", reply->status, reply->content);
        co_return std::unexpected{service_error{reply->status}};
    }

    try {
        co_return read_ann_json(rjson::parse(std::move(reply->content)), schema);
    } catch (const rjson::error& e) {
        vslogger.error("Vector Store returned invalid JSON: {}", e.what());
        co_return std::unexpected{service_reply_format_error{}};
    }
}
/// Test hook: override the minimum interval between DNS refreshes.
/// Calling it on a disabled client is an internal error.
void vector_store_client_tester::set_dns_refresh_interval(vector_store_client& vsc, std::chrono::milliseconds interval) {
    if (!vsc.is_disabled()) {
        vsc._impl->dns_refresh_interval = interval;
        return;
    }
    on_internal_error(vslogger, "Cannot set dns_refresh_interval on a disabled vector store client");
}
/// Test hook: override the timeout for waiting for a new client.
/// Calling it on a disabled client is an internal error.
void vector_store_client_tester::set_wait_for_client_timeout(vector_store_client& vsc, std::chrono::milliseconds timeout) {
    if (!vsc.is_disabled()) {
        vsc._impl->wait_for_client_timeout = timeout;
        return;
    }
    on_internal_error(vslogger, "Cannot set wait_for_client_timeout on a disabled vector store client");
}
/// Test hook: override the stored number of HTTP request retries.
/// Calling it on a disabled client is an internal error.
void vector_store_client_tester::set_http_request_retries(vector_store_client& vsc, unsigned retries) {
    if (!vsc.is_disabled()) {
        vsc._impl->http_request_retries = retries;
        return;
    }
    on_internal_error(vslogger, "Cannot set http_request_retries on a disabled vector store client");
}
/// Test hook: replace the function used to resolve the service host name.
/// Calling it on a disabled client is an internal error.
void vector_store_client_tester::set_dns_resolver(vector_store_client& vsc, std::function<future<std::optional<inet_address>>(sstring const&)> resolver) {
    if (!vsc.is_disabled()) {
        vsc._impl->dns_resolver = std::move(resolver);
        return;
    }
    on_internal_error(vslogger, "Cannot set dns_resolver on a disabled vector store client");
}
/// Test hook: ask the client's impl to refresh DNS via trigger_dns_refresh().
///
/// Calling this on a disabled client (one without an impl) is reported through
/// on_internal_error().
void vector_store_client_tester::trigger_dns_resolver(vector_store_client& vsc) {
    if (vsc.is_disabled()) {
        on_internal_error(vslogger, "Cannot trigger a dns resolver on a disabled vector store client");
    }
    auto& state = *vsc._impl;
    state.trigger_dns_refresh();
}
/// Test hook: wait for the impl's get_client(as) and report the address the impl holds.
///
/// Returns std::nullopt when get_client() yields a falsy result; otherwise
/// returns the impl's `addr` member. Calling this on a disabled client is
/// reported through on_internal_error().
auto vector_store_client_tester::resolve_hostname(vector_store_client& vsc, abort_source& as) -> future<std::optional<inet_address>> {
    if (vsc.is_disabled()) {
        on_internal_error(vslogger, "Cannot check hostname resolving on a disabled vector store client");
    }
    auto& state = *vsc._impl;
    auto client = co_await state.get_client(as);
    if (client) {
        co_return state.addr;
    }
    co_return std::nullopt;
}
} // namespace service

View File

@@ -0,0 +1,112 @@
/*
* Copyright (C) 2025-present ScyllaDB
*/
/*
* SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
*/
#pragma once
#include "seastarx.hh"
#include <seastar/core/shared_future.hh>
#include <seastar/core/shared_ptr.hh>
#include <seastar/http/reply.hh>
#include <expected>
class schema;
namespace cql3::statements {
class primary_key;
}
namespace db {
class config;
}
namespace seastar::net {
class inet_address;
}
namespace service {
/// Client for the vector-store service.
///
/// All state lives behind `_impl`. A client whose `_impl` is null is
/// "disabled": is_disabled() returns true and host()/port() return an
/// error instead of a value.
class vector_store_client final {
    struct impl;
    std::unique_ptr<impl> _impl;

public:
    using config = db::config;
    using embedding = std::vector<float>;
    using host_name = sstring;
    using index_name = sstring;
    using keyspace_name = sstring;
    using limit = std::size_t;
    using port_number = std::uint16_t;
    using primary_key = cql3::statements::primary_key;
    using primary_keys = std::vector<primary_key>;
    using schema_ptr = lw_shared_ptr<schema const>;
    using status_type = http::reply::status_type;

    // --- Error alternatives carried by `ann_error` ---

    /// The vector_store_client service is disabled.
    struct disabled {};

    /// The operation was aborted.
    struct aborted {};

    /// The vector-store address is unavailable (it could not be obtained from the DNS service).
    struct addr_unavailable {};

    /// The vector-store service is unavailable.
    struct service_unavailable {};

    /// The vector-store service replied with an error.
    struct service_error {
        status_type status; ///< The HTTP status code returned by the vector-store service.
    };

    /// The reply from the vector-store service has an unsupported format.
    struct service_reply_format_error {};

    /// Every way an ann() request can fail.
    using ann_error = std::variant<disabled, aborted, addr_unavailable, service_unavailable, service_error, service_reply_format_error>;

    /// Build the client from the database configuration.
    explicit vector_store_client(config const& cfg);
    ~vector_store_client();

    /// Start background tasks.
    void start_background_tasks();

    /// Stop the service.
    auto stop() -> future<>;

    /// Tell whether the client is disabled, i.e. it has no implementation object.
    auto is_disabled() const -> bool {
        return _impl == nullptr;
    }

    /// Get the current host name, or `disabled` when the client is disabled.
    [[nodiscard]] auto host() const -> std::expected<host_name, disabled>;

    /// Get the current port number, or `disabled` when the client is disabled.
    [[nodiscard]] auto port() const -> std::expected<port_number, disabled>;

    /// Ask the vector-store service for the primary keys of the nearest neighbors.
    ///
    /// \param keyspace  keyspace of the index to query
    /// \param name      name of the index to query
    /// \param schema    schema passed to the reply parser
    /// \param embedding the query vector
    /// \param limit     value sent as the request's "limit"
    /// \param as        abort source for the request
    /// \return the primary keys on success, or the `ann_error` describing the failure
    auto ann(keyspace_name keyspace, index_name name, schema_ptr schema, embedding embedding, limit limit, abort_source& as)
            -> future<std::expected<primary_keys, ann_error>>;

private:
    friend struct vector_store_client_tester;
};
/// Test-only access to the internals of a vector_store_client.
///
/// Each function reports an internal error (on_internal_error) when called on
/// a disabled client, because there is no implementation object to act on.
struct vector_store_client_tester {
/// Overwrite the impl's `dns_refresh_interval`.
static void set_dns_refresh_interval(vector_store_client& vsc, std::chrono::milliseconds interval);
/// Overwrite the impl's `wait_for_client_timeout`.
static void set_wait_for_client_timeout(vector_store_client& vsc, std::chrono::milliseconds timeout);
/// Overwrite the impl's `http_request_retries`.
static void set_http_request_retries(vector_store_client& vsc, unsigned retries);
/// Replace the impl's `dns_resolver` callback (host name -> optional address).
static void set_dns_resolver(vector_store_client& vsc, std::function<future<std::optional<net::inet_address>>(sstring const&)> resolver);
/// Call the impl's trigger_dns_refresh().
static void trigger_dns_resolver(vector_store_client& vsc);
/// Wait for the impl's get_client(as); yields std::nullopt when that result is
/// falsy, otherwise the impl's `addr`.
static auto resolve_hostname(vector_store_client& vsc, abort_source& as) -> future<std::optional<net::inet_address>>;
};
} // namespace service

View File

@@ -286,6 +286,8 @@ add_scylla_test(wrapping_interval_test
KIND BOOST)
add_scylla_test(address_map_test
KIND SEASTAR)
add_scylla_test(vector_store_client_test
KIND SEASTAR)
add_scylla_test(combined_tests
KIND SEASTAR

View File

@@ -0,0 +1,540 @@
/*
* Copyright (C) 2025-present ScyllaDB
*/
/*
* SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
*/
#include "service/vector_store_client.hh"
#include "db/config.hh"
#include "exceptions/exceptions.hh"
#include "cql3/statements/select_statement.hh"
#include "test/lib/cql_test_env.hh"
#include "test/lib/log.hh"
#include <memory>
#include <seastar/core/shared_ptr.hh>
#include <seastar/net/api.hh>
#include <seastar/http/function_handlers.hh>
#include <seastar/http/httpd.hh>
#include <seastar/json/json_elements.hh>
#include <seastar/net/dns.hh>
#include <seastar/net/inet_address.hh>
#include <seastar/net/socket_defs.hh>
#include <seastar/testing/test_case.hh>
#include <seastar/testing/thread_test_case.hh>
#include <seastar/util/short_streams.hh>
#include <variant>
namespace {
using namespace seastar;
using vector_store_client = service::vector_store_client;
using vector_store_client_tester = service::vector_store_client_tester;
using config = vector_store_client::config;
using configuration_exception = exceptions::configuration_exception;
using inet_address = seastar::net::inet_address;
using function_handler = httpd::function_handler;
using http_server = httpd::http_server;
using http_server_tester = httpd::http_server_tester;
using milliseconds = std::chrono::milliseconds;
using operation_type = httpd::operation_type;
using port_number = vector_store_client::port_number;
using reply = http::reply;
using request = http::request;
using routes = httpd::routes;
using status_type = http::reply::status_type;
using url = httpd::url;
constexpr auto const* LOCALHOST = "127.0.0.1";
/// Produce a localhost port number that nothing should be listening on.
///
/// A listening socket is opened on LOCALHOST with port 0, the port it was
/// bound to is recorded, and accepting is aborted. The port is expected to stay
/// unused for a while afterwards. This is not guaranteed to be robust, but it
/// should be good enough for most tests.
auto generate_unavailable_localhost_port() -> port_number {
    auto probe = listen(socket_address(net::inet_address(LOCALHOST), 0));
    auto const free_port = probe.local_address().port();
    probe.abort_accept();
    return free_port;
}
/// Make `server` listen on LOCALHOST with port 0 and report the port it got.
///
/// \param server the HTTP server to start listening; ownership is handed back in the result
/// \return the server together with a socket_address built from the bound port
auto listen_on_ephemeral_port(std::unique_ptr<http_server> server) -> future<std::tuple<std::unique_ptr<http_server>, socket_address>> {
    auto inaddr = net::inet_address(LOCALHOST);
    auto const addr = socket_address(inaddr, 0);
    co_await server->listen(addr);
    auto const& listeners = http_server_tester::listeners(*server);
    // REQUIRE rather than CHECK: the first listener is indexed right below, so
    // the test must stop here instead of reading from an empty container.
    BOOST_REQUIRE_EQUAL(listeners.size(), 1);
    // Read the port before handing the server over to the result tuple.
    auto const port = listeners[0].local_address().port();
    co_return std::make_tuple(std::move(server), port);
}
/// Create the mock HTTP server used by the tests and start it on an ephemeral port.
///
/// \param set_routes callback that registers handlers on the server's routes
/// \return the listening server and the address produced by listen_on_ephemeral_port()
auto new_http_server(std::function<void(routes& r)> set_routes) -> future<std::tuple<std::unique_ptr<http_server>, socket_address>> {
    auto mock = std::make_unique<http_server>("test_vector_store_client");
    // Routes are installed before the server starts listening.
    set_routes(mock->_routes);
    mock->set_content_streaming(true);
    co_return co_await listen_on_ephemeral_port(std::move(mock));
}
/// Keep calling `func` until it yields true or `timeout` has elapsed.
///
/// `func` is always called at least once; the elapsed time is only examined
/// after a call that yielded false.
///
/// \return true if `func` yielded true, false if the timeout was exceeded first
auto repeat_until(milliseconds timeout, std::function<future<bool>()> func) -> future<bool> {
    auto const started = lowres_clock::now();
    for (;;) {
        if (co_await func()) {
            co_return true;
        }
        if (lowres_clock::now() - started > timeout) {
            co_return false;
        }
    }
}
/// Render an address as text so the tests can compare it with a string literal.
auto print_addr(const inet_address& addr) -> sstring {
    return seastar::format("{}", addr);
}
} // namespace
/// Constructor behaviour: default config, a valid URI, and a series of invalid URIs.
BOOST_AUTO_TEST_CASE(vector_store_client_test_ctor) {
    // No URI configured: the client is disabled and exposes neither host nor port.
    {
        auto conf = config();
        auto client = vector_store_client{conf};
        BOOST_CHECK(client.is_disabled());
        BOOST_CHECK(!client.host());
        BOOST_CHECK(!client.port());
    }

    // Well-formed URI: the client is enabled and reports the parsed host and port.
    {
        auto conf = config();
        conf.vector_store_uri.set("http://good.authority.com:6080");
        auto client = vector_store_client{conf};
        BOOST_CHECK(!client.is_disabled());
        BOOST_CHECK_EQUAL(*client.host(), "good.authority.com");
        BOOST_CHECK_EQUAL(*client.port(), 6080);
    }

    // Malformed URIs: construction must throw configuration_exception.
    {
        auto conf = config();

        // invalid character in the authority
        conf.vector_store_uri.set("http://bad,authority.com:6080");
        BOOST_CHECK_THROW(vector_store_client{conf}, configuration_exception);

        // unsupported scheme
        conf.vector_store_uri.set("bad-schema://authority.com:6080");
        BOOST_CHECK_THROW(vector_store_client{conf}, configuration_exception);

        // non-numeric port
        conf.vector_store_uri.set("http://bad.port.com:a6080");
        BOOST_CHECK_THROW(vector_store_client{conf}, configuration_exception);

        // port number too large
        conf.vector_store_uri.set("http://bad.port.com:60806080");
        BOOST_CHECK_THROW(vector_store_client{conf}, configuration_exception);

        // two port separators
        conf.vector_store_uri.set("http://bad.format.com:60:80");
        BOOST_CHECK_THROW(vector_store_client{conf}, configuration_exception);

        // a path after the authority
        conf.vector_store_uri.set("http://authority.com:6080/bad/path");
        BOOST_CHECK_THROW(vector_store_client{conf}, configuration_exception);
    }
}
/// Hostname resolution begins once start_background_tasks() has been called.
SEASTAR_TEST_CASE(vector_store_client_test_dns_started) {
    auto conf = config();
    conf.vector_store_uri.set("http://good.authority.here:6080");
    auto client = vector_store_client{conf};
    BOOST_CHECK(!client.is_disabled());

    vector_store_client_tester::set_dns_refresh_interval(client, milliseconds(2000));
    vector_store_client_tester::set_wait_for_client_timeout(client, milliseconds(100));
    // Stub resolver: checks the requested name and always answers with 127.0.0.1.
    vector_store_client_tester::set_dns_resolver(client, [](auto const& name) -> future<std::optional<inet_address>> {
        BOOST_CHECK_EQUAL(name, "good.authority.here");
        co_return inet_address("127.0.0.1");
    });

    client.start_background_tasks();

    auto aborter = abort_source();
    auto const resolved = co_await vector_store_client_tester::resolve_hostname(client, aborter);
    BOOST_REQUIRE(resolved);
    BOOST_CHECK_EQUAL(print_addr(*resolved), "127.0.0.1");

    co_await client.stop();
}
/// When the resolver cannot resolve the hostname, no address is reported.
SEASTAR_TEST_CASE(vector_store_client_test_dns_resolve_failure) {
    auto conf = config();
    conf.vector_store_uri.set("http://good.authority.here:6080");
    auto client = vector_store_client{conf};
    BOOST_CHECK(!client.is_disabled());

    vector_store_client_tester::set_dns_refresh_interval(client, milliseconds(2000));
    vector_store_client_tester::set_wait_for_client_timeout(client, milliseconds(100));
    // Stub resolver: checks the requested name and never finds an address.
    vector_store_client_tester::set_dns_resolver(client, [](auto const& name) -> future<std::optional<inet_address>> {
        BOOST_CHECK_EQUAL(name, "good.authority.here");
        co_return std::nullopt;
    });

    client.start_background_tasks();

    auto aborter = abort_source();
    auto const resolved = co_await vector_store_client_tester::resolve_hostname(client, aborter);
    BOOST_CHECK(!resolved);

    co_await client.stop();
}
/// Resolution is retried after failures until the resolver succeeds.
SEASTAR_TEST_CASE(vector_store_client_test_dns_resolving_repeated) {
    auto conf = config();
    conf.vector_store_uri.set("http://good.authority.here:6080");
    auto client = vector_store_client{conf};
    BOOST_CHECK(!client.is_disabled());

    vector_store_client_tester::set_dns_refresh_interval(client, milliseconds(10));
    vector_store_client_tester::set_wait_for_client_timeout(client, milliseconds(20));
    // Stub resolver: fails twice, then succeeds on every third call. The
    // returned address encodes the call number so the test can tell calls apart.
    auto calls = 0;
    vector_store_client_tester::set_dns_resolver(client, [&calls](auto const& name) -> future<std::optional<inet_address>> {
        BOOST_CHECK_EQUAL(name, "good.authority.here");
        ++calls;
        if (calls % 3 == 0) {
            co_return inet_address(format("127.0.0.{}", calls));
        }
        co_return std::nullopt;
    });

    client.start_background_tasks();

    auto aborter = abort_source();

    // First successful resolution: expected on the third resolver call.
    BOOST_CHECK(co_await repeat_until(milliseconds(1000), [&client, &aborter]() -> future<bool> {
        auto const resolved = co_await vector_store_client_tester::resolve_hostname(client, aborter);
        co_return resolved.has_value();
    }));
    BOOST_CHECK_EQUAL(calls, 3);
    auto addr = co_await vector_store_client_tester::resolve_hostname(client, aborter);
    BOOST_REQUIRE(addr);
    BOOST_CHECK_EQUAL(print_addr(*addr), "127.0.0.3");

    // Force a refresh: the address first becomes unavailable, then resolves again.
    vector_store_client_tester::trigger_dns_resolver(client);
    BOOST_CHECK(co_await repeat_until(milliseconds(1000), [&client, &aborter]() -> future<bool> {
        auto const resolved = co_await vector_store_client_tester::resolve_hostname(client, aborter);
        co_return !resolved.has_value();
    }));
    BOOST_CHECK(co_await repeat_until(milliseconds(1000), [&client, &aborter]() -> future<bool> {
        auto const resolved = co_await vector_store_client_tester::resolve_hostname(client, aborter);
        co_return resolved.has_value();
    }));
    BOOST_CHECK_EQUAL(calls, 6);
    addr = co_await vector_store_client_tester::resolve_hostname(client, aborter);
    BOOST_REQUIRE(addr);
    BOOST_CHECK_EQUAL(print_addr(*addr), "127.0.0.6");

    co_await client.stop();
}
/// A burst of refresh triggers does not produce a matching burst of resolver calls.
SEASTAR_TEST_CASE(vector_store_client_test_dns_refresh_respects_interval) {
    auto conf = config();
    conf.vector_store_uri.set("http://good.authority.here:6080");
    auto client = vector_store_client{conf};
    BOOST_CHECK(!client.is_disabled());

    vector_store_client_tester::set_dns_refresh_interval(client, milliseconds(10));
    vector_store_client_tester::set_wait_for_client_timeout(client, milliseconds(100));
    // Stub resolver: counts its invocations and always answers with 127.0.0.1.
    auto calls = 0;
    vector_store_client_tester::set_dns_resolver(client, [&calls](auto const& name) -> future<std::optional<inet_address>> {
        BOOST_CHECK_EQUAL(name, "good.authority.here");
        ++calls;
        co_return inet_address("127.0.0.1");
    });

    client.start_background_tasks();
    co_await sleep(milliseconds(20)); // wait for the first DNS refresh

    auto aborter = abort_source();
    auto addr = co_await vector_store_client_tester::resolve_hostname(client, aborter);
    BOOST_REQUIRE(addr);
    BOOST_CHECK_EQUAL(print_addr(*addr), "127.0.0.1");
    BOOST_CHECK_EQUAL(calls, 1);

    // Fire five triggers back to back and count how many resolver calls follow.
    calls = 0;
    for (auto i = 0; i < 5; ++i) {
        vector_store_client_tester::trigger_dns_resolver(client);
    }
    co_await sleep(milliseconds(100)); // wait for the next DNS refresh

    addr = co_await vector_store_client_tester::resolve_hostname(client, aborter);
    BOOST_REQUIRE(addr);
    BOOST_CHECK_EQUAL(print_addr(*addr), "127.0.0.1");
    BOOST_CHECK_GE(calls, 1);
    BOOST_CHECK_LE(calls, 2);

    co_await client.stop();
}
/// Waiting for a DNS refresh can be cut short through the abort source.
SEASTAR_TEST_CASE(vector_store_client_test_dns_refresh_aborted) {
    auto conf = config();
    conf.vector_store_uri.set("http://good.authority.here:6080");
    auto client = vector_store_client{conf};
    BOOST_CHECK(!client.is_disabled());

    vector_store_client_tester::set_dns_refresh_interval(client, milliseconds(10));
    vector_store_client_tester::set_wait_for_client_timeout(client, milliseconds(100));
    // Stub resolver: takes 100ms to answer, longer than the 10ms abort timer below.
    vector_store_client_tester::set_dns_resolver(client, [](auto const& name) -> future<std::optional<inet_address>> {
        BOOST_CHECK_EQUAL(name, "good.authority.here");
        co_await sleep(milliseconds(100));
        co_return inet_address("127.0.0.1");
    });

    client.start_background_tasks();

    auto aborter = abort_source();
    // Request the abort 10ms from now, while the resolver is still sleeping.
    auto abort_timer = timer([&aborter]() {
        aborter.request_abort();
    });
    abort_timer.arm(milliseconds(10));

    auto const resolved = co_await vector_store_client_tester::resolve_hostname(client, aborter);
    BOOST_CHECK(!resolved);

    co_await client.stop();
}
/// With no vector store URI configured, ann() fails with the `disabled` error.
SEASTAR_TEST_CASE(vector_store_client_ann_test_disabled) {
    co_await do_with_cql_env([](cql_test_env& env) -> future<> {
        co_await env.execute_cql(R"(
create table ks.vs (
pk1 tinyint, pk2 tinyint,
ck1 tinyint, ck2 tinyint,
embedding vector<float, 3>,
primary key ((pk1, pk2), ck1, ck2))
)");

        auto schema = env.local_db().find_schema("ks", "vs");
        auto& client = env.local_qp().vector_store_client();

        auto aborter = abort_source();
        auto const query = std::vector<float>{0.1, 0.2, 0.3};
        auto result = co_await client.ann("ks", "idx", schema, query, 2, aborter);
        BOOST_REQUIRE(!result);
        BOOST_CHECK(std::holds_alternative<vector_store_client::disabled>(result.error()));
    });
}
/// When the hostname cannot be resolved, ann() fails with `addr_unavailable`.
SEASTAR_TEST_CASE(vector_store_client_test_ann_addr_unavailable) {
    auto test_cfg = cql_test_config();
    test_cfg.db_config->vector_store_uri.set("http://bad.authority.here:6080");

    co_await do_with_cql_env(
            [](cql_test_env& env) -> future<> {
                co_await env.execute_cql(R"(
create table ks.vs (
pk1 tinyint, pk2 tinyint,
ck1 tinyint, ck2 tinyint,
embedding vector<float, 3>,
primary key ((pk1, pk2), ck1, ck2))
)");

                auto schema = env.local_db().find_schema("ks", "vs");
                auto& client = env.local_qp().vector_store_client();

                vector_store_client_tester::set_dns_refresh_interval(client, milliseconds(1000));
                vector_store_client_tester::set_wait_for_client_timeout(client, milliseconds(100));
                vector_store_client_tester::set_http_request_retries(client, 3);
                // Stub resolver: checks the requested name and never finds an address.
                vector_store_client_tester::set_dns_resolver(client, [](auto const& name) -> future<std::optional<inet_address>> {
                    BOOST_CHECK_EQUAL(name, "bad.authority.here");
                    co_return std::nullopt;
                });

                client.start_background_tasks();

                auto aborter = abort_source();
                auto const query = std::vector<float>{0.1, 0.2, 0.3};
                auto result = co_await client.ann("ks", "idx", schema, query, 2, aborter);
                BOOST_REQUIRE(!result);
                BOOST_CHECK(std::holds_alternative<vector_store_client::addr_unavailable>(result.error()));
            },
            test_cfg);
}
/// When the address resolves but nothing listens on the port, ann() fails with `service_unavailable`.
SEASTAR_TEST_CASE(vector_store_client_test_ann_service_unavailable) {
    auto test_cfg = cql_test_config();
    auto const dead_port = generate_unavailable_localhost_port();
    test_cfg.db_config->vector_store_uri.set(format("http://good.authority.here:{}", dead_port));

    co_await do_with_cql_env(
            [](cql_test_env& env) -> future<> {
                co_await env.execute_cql(R"(
create table ks.vs (
pk1 tinyint, pk2 tinyint,
ck1 tinyint, ck2 tinyint,
embedding vector<float, 3>,
primary key ((pk1, pk2), ck1, ck2))
)");

                auto schema = env.local_db().find_schema("ks", "vs");
                auto& client = env.local_qp().vector_store_client();

                vector_store_client_tester::set_dns_refresh_interval(client, milliseconds(1000));
                vector_store_client_tester::set_wait_for_client_timeout(client, milliseconds(100));
                vector_store_client_tester::set_http_request_retries(client, 3);
                // Stub resolver: checks the requested name and always answers with 127.0.0.1.
                vector_store_client_tester::set_dns_resolver(client, [](auto const& name) -> future<std::optional<inet_address>> {
                    BOOST_CHECK_EQUAL(name, "good.authority.here");
                    co_return inet_address("127.0.0.1");
                });

                client.start_background_tasks();

                auto aborter = abort_source();
                auto const query = std::vector<float>{0.1, 0.2, 0.3};
                auto result = co_await client.ann("ks", "idx", schema, query, 2, aborter);
                BOOST_REQUIRE(!result);
                BOOST_CHECK(std::holds_alternative<vector_store_client::service_unavailable>(result.error()));
            },
            test_cfg);
}
/// When the abort source fires while ann() is still waiting, it fails with `aborted`.
SEASTAR_TEST_CASE(vector_store_client_test_ann_service_aborted) {
    auto test_cfg = cql_test_config();
    auto const dead_port = generate_unavailable_localhost_port();
    test_cfg.db_config->vector_store_uri.set(format("http://good.authority.here:{}", dead_port));

    co_await do_with_cql_env(
            [](cql_test_env& env) -> future<> {
                co_await env.execute_cql(R"(
create table ks.vs (
pk1 tinyint, pk2 tinyint,
ck1 tinyint, ck2 tinyint,
embedding vector<float, 3>,
primary key ((pk1, pk2), ck1, ck2))
)");

                auto schema = env.local_db().find_schema("ks", "vs");
                auto& client = env.local_qp().vector_store_client();

                vector_store_client_tester::set_dns_refresh_interval(client, milliseconds(10));
                vector_store_client_tester::set_wait_for_client_timeout(client, milliseconds(100));
                vector_store_client_tester::set_http_request_retries(client, 3);
                // Stub resolver: takes 100ms to answer, longer than the 10ms abort timer below.
                vector_store_client_tester::set_dns_resolver(client, [](auto const& name) -> future<std::optional<inet_address>> {
                    BOOST_CHECK_EQUAL(name, "good.authority.here");
                    co_await sleep(milliseconds(100));
                    co_return inet_address("127.0.0.1");
                });

                client.start_background_tasks();

                auto aborter = abort_source();
                // Request the abort 10ms from now, while the resolver is still sleeping.
                auto abort_timer = timer([&aborter]() {
                    aborter.request_abort();
                });
                abort_timer.arm(milliseconds(10));

                auto const query = std::vector<float>{0.1, 0.2, 0.3};
                auto result = co_await client.ann("ks", "idx", schema, query, 2, aborter);
                BOOST_REQUIRE(!result);
                BOOST_CHECK(std::holds_alternative<vector_store_client::aborted>(result.error()));
            },
            test_cfg);
}
/// Exercise ann() against a mocked Vector Store HTTP server.
///
/// The mock registers one POST handler. Each handled request pops one
/// (expected request body, reply body) pair from `ann_replies`, checks the
/// received body against the expectation and answers 200 with the queued
/// reply body. The test then walks through a wrong endpoint, a series of
/// malformed replies, and finally a well-formed reply.
SEASTAR_TEST_CASE(vector_store_client_test_ann_request) {
    auto ann_replies = make_lw_shared<std::queue<std::tuple<sstring, sstring>>>();
    auto [server, addr] = co_await new_http_server([ann_replies](routes& r) {
        auto ann = [ann_replies](std::unique_ptr<request> req, std::unique_ptr<reply> rep) -> future<std::unique_ptr<reply>> {
            BOOST_REQUIRE(!ann_replies->empty());
            auto [req_exp, rep_inp] = ann_replies->front();
            auto const req_inp = co_await util::read_entire_stream_contiguous(*req->content_stream);
            BOOST_CHECK_EQUAL(req_inp, req_exp);
            ann_replies->pop();
            rep->set_status(status_type::ok);
            rep->write_body("json", rep_inp);
            co_return rep;
        };
        r.add(operation_type::POST, url("/api/v1/indexes/ks/idx").remainder("ann"), new function_handler(ann, "json"));
    });

    auto cfg = cql_test_config();
    cfg.db_config->vector_store_uri.set(format("http://good.authority.here:{}", addr.port()));
    co_await do_with_cql_env(
            [&ann_replies](cql_test_env& env) -> future<> {
                co_await env.execute_cql(R"(
create table ks.vs (
pk1 tinyint, pk2 tinyint,
ck1 tinyint, ck2 tinyint,
embedding vector<float, 3>,
primary key ((pk1, pk2), ck1, ck2))
)");

                auto schema = env.local_db().find_schema("ks", "vs");
                auto& vs = env.local_qp().vector_store_client();
                vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(1000));
                vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(100));
                vector_store_client_tester::set_http_request_retries(vs, 3);
                vector_store_client_tester::set_dns_resolver(vs, [](auto const& host) -> future<std::optional<inet_address>> {
                    BOOST_CHECK_EQUAL(host, "good.authority.here");
                    co_return inet_address("127.0.0.1");
                });

                vs.start_background_tasks();

                // set the wrong idx (wrong endpoint) - service should return 404
                auto as = abort_source();
                auto keys = co_await vs.ann("ks", "idx2", schema, std::vector<float>{0.1, 0.2, 0.3}, 2, as);
                BOOST_REQUIRE(!keys);
                auto* err = std::get_if<vector_store_client::service_error>(&keys.error());
                // REQUIRE rather than CHECK: `err` is dereferenced on the next line.
                BOOST_REQUIRE(err != nullptr);
                BOOST_CHECK_EQUAL(err->status, status_type::not_found);

                // missing primary_keys in the reply - service should return format error
                ann_replies->emplace(std::make_tuple(R"({"embedding":[0.1,0.2,0.3],"limit":2})",
                        R"({"primary_keys1":{"pk1":[5,6],"pk2":[7,8],"ck1":[9,1],"ck2":[2,3]},"distances":[0.1,0.2]})"));
                constexpr auto MAX_WAIT = std::chrono::seconds(5);
                auto const now = lowres_clock::now();
                for (;;) {
                    keys = co_await vs.ann("ks", "idx", schema, std::vector<float>{0.1, 0.2, 0.3}, 2, as);
                    BOOST_REQUIRE(!keys);
                    // Enforce the deadline on every attempt, so a server that never
                    // becomes ready fails the test instead of retrying forever.
                    BOOST_REQUIRE(lowres_clock::now() - now < MAX_WAIT);
                    // if the service is unavailable or 400, retry, seems http server is not ready yet
                    auto* const unavailable = std::get_if<vector_store_client::service_unavailable>(&keys.error());
                    auto* const service_error = std::get_if<vector_store_client::service_error>(&keys.error());
                    if ((unavailable == nullptr && service_error == nullptr) ||
                            (service_error != nullptr && service_error->status != status_type::bad_request)) {
                        break;
                    }
                }
                BOOST_CHECK(std::holds_alternative<vector_store_client::service_reply_format_error>(keys.error()));

                // missing distances in the reply - service should return format error
                ann_replies->emplace(std::make_tuple(R"({"embedding":[0.1,0.2,0.3],"limit":2})",
                        R"({"primary_keys":{"pk1":[5,6],"pk2":[7,8],"ck1":[9,1],"ck2":[2,3]},"distances1":[0.1,0.2]})"));
                keys = co_await vs.ann("ks", "idx", schema, std::vector<float>{0.1, 0.2, 0.3}, 2, as);
                BOOST_REQUIRE(!keys);
                BOOST_CHECK(std::holds_alternative<vector_store_client::service_reply_format_error>(keys.error()));

                // missing pk1 key in the reply - service should return format error
                ann_replies->emplace(std::make_tuple(R"({"embedding":[0.1,0.2,0.3],"limit":2})",
                        R"({"primary_keys":{"pk11":[5,6],"pk2":[7,8],"ck1":[9,1],"ck2":[2,3]},"distances":[0.1,0.2]})"));
                keys = co_await vs.ann("ks", "idx", schema, std::vector<float>{0.1, 0.2, 0.3}, 2, as);
                BOOST_REQUIRE(!keys);
                BOOST_CHECK(std::holds_alternative<vector_store_client::service_reply_format_error>(keys.error()));

                // missing ck1 key in the reply - service should return format error
                ann_replies->emplace(std::make_tuple(R"({"embedding":[0.1,0.2,0.3],"limit":2})",
                        R"({"primary_keys":{"pk1":[5,6],"pk2":[7,8],"ck11":[9,1],"ck2":[2,3]},"distances":[0.1,0.2]})"));
                keys = co_await vs.ann("ks", "idx", schema, std::vector<float>{0.1, 0.2, 0.3}, 2, as);
                BOOST_REQUIRE(!keys);
                BOOST_CHECK(std::holds_alternative<vector_store_client::service_reply_format_error>(keys.error()));

                // wrong size of pk2 key in the reply - service should return format error
                ann_replies->emplace(std::make_tuple(R"({"embedding":[0.1,0.2,0.3],"limit":2})",
                        R"({"primary_keys":{"pk1":[5,6],"pk2":[78],"ck1":[9,1],"ck2":[2,3]},"distances":[0.1,0.2]})"));
                keys = co_await vs.ann("ks", "idx", schema, std::vector<float>{0.1, 0.2, 0.3}, 2, as);
                BOOST_REQUIRE(!keys);
                BOOST_CHECK(std::holds_alternative<vector_store_client::service_reply_format_error>(keys.error()));

                // wrong size of ck2 key in the reply - service should return format error
                ann_replies->emplace(std::make_tuple(R"({"embedding":[0.1,0.2,0.3],"limit":2})",
                        R"({"primary_keys":{"pk1":[5,6],"pk2":[7,8],"ck1":[9,1],"ck2":[23]},"distances":[0.1,0.2]})"));
                keys = co_await vs.ann("ks", "idx", schema, std::vector<float>{0.1, 0.2, 0.3}, 2, as);
                BOOST_REQUIRE(!keys);
                BOOST_CHECK(std::holds_alternative<vector_store_client::service_reply_format_error>(keys.error()));

                // correct reply - service should return keys
                ann_replies->emplace(std::make_tuple(R"({"embedding":[0.1,0.2,0.3],"limit":2})",
                        R"({"primary_keys":{"pk1":[5,6],"pk2":[7,8],"ck1":[9,1],"ck2":[2,3]},"distances":[0.1,0.2]})"));
                keys = co_await vs.ann("ks", "idx", schema, std::vector<float>{0.1, 0.2, 0.3}, 2, as);
                BOOST_REQUIRE(keys);
                BOOST_REQUIRE_EQUAL(keys->size(), 2);
                BOOST_CHECK_EQUAL(seastar::format("{}", keys->at(0).partition.key().explode()), "[05, 07]");
                BOOST_CHECK_EQUAL(seastar::format("{}", keys->at(0).clustering.explode()), "[09, 02]");
                BOOST_CHECK_EQUAL(seastar::format("{}", keys->at(1).partition.key().explode()), "[06, 08]");
                BOOST_CHECK_EQUAL(seastar::format("{}", keys->at(1).clustering.explode()), "[01, 03]");
            },
            cfg);

    co_await server->stop();
}

View File

@@ -170,6 +170,7 @@ private:
sharded<gms::gossip_address_map> _gossip_address_map;
sharded<service::direct_fd_pinger> _fd_pinger;
sharded<cdc::cdc_service> _cdc;
sharded<service::vector_store_client> _vector_store_client;
db::config* _db_config;
service::raft_group0_client* _group0_client;
@@ -704,7 +705,12 @@ private:
std::chrono::duration_cast<std::chrono::milliseconds>(cql3::prepared_statements_cache::entry_expiry));
auth_prep_cache_config.refresh = std::chrono::milliseconds(cfg->permissions_update_interval_in_ms());
_qp.start(std::ref(_proxy), std::move(local_data_dict), std::ref(_mnotifier), qp_mcfg, std::ref(_cql_config), auth_prep_cache_config, std::ref(_lang_manager)).get();
_vector_store_client.start(std::ref(*cfg)).get();
auto stop_vector_store_client = defer_verbose_shutdown("vector store client", [this] {
_vector_store_client.stop().get();
});
_qp.start(std::ref(_proxy), std::move(local_data_dict), std::ref(_mnotifier), std::ref(_vector_store_client), qp_mcfg, std::ref(_cql_config), auth_prep_cache_config, std::ref(_lang_manager)).get();
auto stop_qp = defer_verbose_shutdown("query processor", [this] { _qp.stop().get(); });
_elc_notif.start().get();

View File

@@ -21,6 +21,7 @@ template<typename T>
class sequential_producer {
public:
using factory_t = std::function<seastar::future<T>()>;
using time_point = seastar::shared_future<T>::time_point;
private:
factory_t _factory;
@@ -32,11 +33,18 @@ class sequential_producer {
clear();
}
seastar::future<T> operator()() {
seastar::future<T> operator()(time_point timeout = time_point::max()) {
if (_churning.available()) {
_churning = _factory();
}
return _churning.get_future();
return _churning.get_future(timeout);
}
seastar::future<T> operator()(seastar::abort_source& as) {
if (_churning.available()) {
_churning = _factory();
}
return _churning.get_future(as);
}
void clear() {