mirror of
https://github.com/scylladb/scylladb.git
synced 2026-05-13 03:12:13 +00:00
Merge 'vector_store_client: implement vector_store_client service' from Pawel Pery
Vector Store service is a http server which provides vector search index and an ANN (Approximate Nearest Neighbor) functionality. Vector Store retrieves metadata & data from Scylla about indexes using CQL protocol & CDC functionality. Scylla will request ann search using http api. Commits for the patch: - implement initial `vector_store_client` service. It adds also a parameter `vector_store_uri` to the scylla. - refactor sequential_producer as abortable - implement ip addr retrieval from dns. The uri for Vector Store must contains dns name, this commit implements ip addr refreshing functionality - refactor primary_key as a top-level class. It is needed for the forward declaration of a primary_key - implement ANN API. It implements a core ANN search request functionality, adds Vector Store HTTP API description in docs/protocols.md, and implements automatic boost tests with mocked http server for checking error conditions. New feature, should not be backported. Fixes: VECTOR-47 Fixes: VECTOR-45 -~- Closes scylladb/scylladb#24331 * github.com:scylladb/scylladb: vector_store_client: implement ANN API cql3: refactor primary_key as a top-level class vector_store_client: implement ip addr retrieval from dns utils: refactor sequential_producer as abortable vector_store_client: implement initial vector_store_client service
This commit is contained in:
@@ -855,3 +855,10 @@ rf_rack_valid_keyspaces: false
|
||||
# Maximum number of items in single BatchWriteItem command. Default is 100.
|
||||
# Note: DynamoDB has a hard-coded limit of 25.
|
||||
# alternator_max_items_in_batch_write: 100
|
||||
|
||||
#
|
||||
# Vector Store options
|
||||
#
|
||||
# Uri for the vector store using dns name. Only http schema is supported. Port number is mandatory.
|
||||
# Default is empty, which means that the vector store is not used.
|
||||
# vector_store_uri: http://vector-store.dns.name:{port}
|
||||
|
||||
@@ -567,6 +567,7 @@ scylla_tests = set([
|
||||
'test/boost/symmetric_key_test',
|
||||
'test/boost/types_test',
|
||||
'test/boost/utf8_test',
|
||||
'test/boost/vector_store_client_test',
|
||||
'test/boost/vint_serialization_test',
|
||||
'test/boost/virtual_table_mutation_source_test',
|
||||
'test/boost/wasm_alloc_test',
|
||||
@@ -1219,6 +1220,7 @@ scylla_core = (['message/messaging_service.cc',
|
||||
'node_ops/task_manager_module.cc',
|
||||
'reader_concurrency_semaphore_group.cc',
|
||||
'utils/disk_space_monitor.cc',
|
||||
'service/vector_store_client.cc',
|
||||
] + [Antlr3Grammar('cql3/Cql.g')] \
|
||||
+ scylla_raft_core
|
||||
)
|
||||
|
||||
@@ -27,6 +27,7 @@
|
||||
#include "cql3/untyped_result_set.hh"
|
||||
#include "db/config.hh"
|
||||
#include "data_dictionary/data_dictionary.hh"
|
||||
#include "service/vector_store_client.hh"
|
||||
#include "utils/hashers.hh"
|
||||
#include "utils/error_injection.hh"
|
||||
#include "service/migration_manager.hh"
|
||||
@@ -68,11 +69,12 @@ static service::query_state query_state_for_internal_call() {
|
||||
return {service::client_state::for_internal_calls(), empty_service_permit()};
|
||||
}
|
||||
|
||||
query_processor::query_processor(service::storage_proxy& proxy, data_dictionary::database db, service::migration_notifier& mn, query_processor::memory_config mcfg, cql_config& cql_cfg, utils::loading_cache_config auth_prep_cache_cfg, lang::manager& langm)
|
||||
query_processor::query_processor(service::storage_proxy& proxy, data_dictionary::database db, service::migration_notifier& mn, service::vector_store_client& vsc, query_processor::memory_config mcfg, cql_config& cql_cfg, utils::loading_cache_config auth_prep_cache_cfg, lang::manager& langm)
|
||||
: _migration_subscriber{std::make_unique<migration_subscriber>(this)}
|
||||
, _proxy(proxy)
|
||||
, _db(db)
|
||||
, _mnotifier(mn)
|
||||
, _vector_store_client(vsc)
|
||||
, _mcfg(mcfg)
|
||||
, _cql_config(cql_cfg)
|
||||
, _prepared_cache(prep_cache_log, _mcfg.prepared_statment_cache_size)
|
||||
|
||||
@@ -28,6 +28,7 @@
|
||||
#include "transport/messages/result_message.hh"
|
||||
#include "service/client_state.hh"
|
||||
#include "service/broadcast_tables/experimental/query_result.hh"
|
||||
#include "service/vector_store_client.hh"
|
||||
#include "utils/assert.hh"
|
||||
#include "utils/observable.hh"
|
||||
#include "service/raft/raft_group0_client.hh"
|
||||
@@ -107,6 +108,7 @@ private:
|
||||
service::storage_proxy& _proxy;
|
||||
data_dictionary::database _db;
|
||||
service::migration_notifier& _mnotifier;
|
||||
service::vector_store_client& _vector_store_client;
|
||||
memory_config _mcfg;
|
||||
const cql_config& _cql_config;
|
||||
|
||||
@@ -146,7 +148,7 @@ public:
|
||||
static std::unique_ptr<statements::raw::parsed_statement> parse_statement(const std::string_view& query, dialect d);
|
||||
static std::vector<std::unique_ptr<statements::raw::parsed_statement>> parse_statements(std::string_view queries, dialect d);
|
||||
|
||||
query_processor(service::storage_proxy& proxy, data_dictionary::database db, service::migration_notifier& mn, memory_config mcfg, cql_config& cql_cfg, utils::loading_cache_config auth_prep_cache_cfg, lang::manager& langm);
|
||||
query_processor(service::storage_proxy& proxy, data_dictionary::database db, service::migration_notifier& mn, service::vector_store_client& vsc, memory_config mcfg, cql_config& cql_cfg, utils::loading_cache_config auth_prep_cache_cfg, lang::manager& langm);
|
||||
|
||||
~query_processor();
|
||||
|
||||
@@ -176,6 +178,14 @@ public:
|
||||
|
||||
lang::manager& lang() { return _lang_manager; }
|
||||
|
||||
const service::vector_store_client& vector_store_client() const noexcept {
|
||||
return _vector_store_client;
|
||||
}
|
||||
|
||||
service::vector_store_client& vector_store_client() noexcept {
|
||||
return _vector_store_client;
|
||||
}
|
||||
|
||||
db::auth_version_t auth_version;
|
||||
|
||||
statements::prepared_statement::checked_weak_ptr get_prepared(const std::optional<auth::authenticated_user>& user, const prepared_cache_key_type& key) {
|
||||
|
||||
@@ -1471,10 +1471,10 @@ indexed_table_select_statement::find_index_partition_ranges(query_processor& qp,
|
||||
|
||||
// Note: the partitions keys returned by this function are sorted
|
||||
// in token order. See issue #3423.
|
||||
future<coordinator_result<std::tuple<std::vector<indexed_table_select_statement::primary_key>, lw_shared_ptr<const service::pager::paging_state>>>>
|
||||
future<coordinator_result<std::tuple<std::vector<primary_key>, lw_shared_ptr<const service::pager::paging_state>>>>
|
||||
indexed_table_select_statement::find_index_clustering_rows(query_processor& qp, service::query_state& state, const query_options& options) const
|
||||
{
|
||||
using value_type = std::tuple<std::vector<indexed_table_select_statement::primary_key>, lw_shared_ptr<const service::pager::paging_state>>;
|
||||
using value_type = std::tuple<std::vector<primary_key>, lw_shared_ptr<const service::pager::paging_state>>;
|
||||
auto now = gc_clock::now();
|
||||
auto timeout = db::timeout_clock::now() + get_timeout(state.get_client_state(), options);
|
||||
const uint64_t limit = get_inner_loop_limit(get_limit(options, _limit), _selection->is_aggregate());
|
||||
|
||||
@@ -43,6 +43,13 @@ namespace restrictions {
|
||||
|
||||
namespace statements {
|
||||
|
||||
|
||||
/// Encapsulates a partition key and clustering key prefix as a primary key.
|
||||
struct primary_key {
|
||||
dht::decorated_key partition;
|
||||
clustering_key_prefix clustering;
|
||||
};
|
||||
|
||||
/**
|
||||
* Encapsulates a completely parsed SELECT query, including the target
|
||||
* column family, expression, result count, and ordering clause.
|
||||
@@ -134,12 +141,6 @@ public:
|
||||
const query_options& options, gc_clock::time_point now, int32_t page_size, bool aggregate, bool nonpaged_filtering, uint64_t limit,
|
||||
std::optional<service::cas_shard> cas_shard) const;
|
||||
|
||||
|
||||
struct primary_key {
|
||||
dht::decorated_key partition;
|
||||
clustering_key_prefix clustering;
|
||||
};
|
||||
|
||||
future<shared_ptr<cql_transport::messages::result_message>> process_results(foreign_ptr<lw_shared_ptr<query::result>> results,
|
||||
lw_shared_ptr<query::read_command> cmd, const query_options& options, gc_clock::time_point now) const;
|
||||
|
||||
|
||||
@@ -1361,6 +1361,7 @@ db::config::config(std::shared_ptr<db::extensions> exts)
|
||||
// alternator_max_items_in_batch_write matches DynamoDB behaviour of size limit, but with different value - for DynamoDB it's 25
|
||||
// (see DynamoDB's documentation for BatchWriteItem command)
|
||||
, alternator_max_items_in_batch_write(this, "alternator_max_items_in_batch_write", value_status::Used, 100, "Maximum amount of items in single BatchItemWrite call.")
|
||||
, vector_store_uri(this, "vector_store_uri", value_status::Used, "", "The URI of the vector store to use for vector search. If not set, vector search is disabled.")
|
||||
, abort_on_ebadf(this, "abort_on_ebadf", value_status::Used, true, "Abort the server on incorrect file descriptor access. Throws exception when disabled.")
|
||||
, redis_port(this, "redis_port", value_status::Used, 0, "Port on which the REDIS transport listens for clients.")
|
||||
, redis_ssl_port(this, "redis_ssl_port", value_status::Used, 0, "Port on which the REDIS TLS native transport listens for clients.")
|
||||
|
||||
@@ -488,6 +488,8 @@ public:
|
||||
named_value<sstring> alternator_describe_endpoints;
|
||||
named_value<uint32_t> alternator_max_items_in_batch_write;
|
||||
|
||||
named_value<sstring> vector_store_uri;
|
||||
|
||||
named_value<bool> abort_on_ebadf;
|
||||
|
||||
named_value<uint16_t> redis_port;
|
||||
|
||||
@@ -276,3 +276,48 @@ to Scylla's REST API port over the loopback (localhost) interface.
|
||||
The port on which scylla-jmx listens is by default port 7199. This port,
|
||||
and the listen address, can be overridden with the `-jp` and `-ja` options
|
||||
(respectively) of the `scylla-jmx` script.
|
||||
|
||||
## Vector Store
|
||||
|
||||
The Vector Search functionality within Scylla is provided by an external
|
||||
service [vector-store](https://github.com/scylla/vector-store). That service is
|
||||
responsible for creating and managing vector indexes built from data retrieved
|
||||
from Scylla using the CQL protocol and CDC functionality from a table and a
|
||||
custom index created in Scylla using `CREATE INDEX {keyspace}.{index} ON
|
||||
{table}({embedding) USING 'vector_index'`. Scylla does a vector search by
|
||||
delegating the search to the vector-store service using the HTTP API protocol.
|
||||
Scylla is the HTTP client of the vector-store service.
|
||||
|
||||
The supported vector-store HTTP API:
|
||||
|
||||
### `POST /api/v1/indexes/{keyspace}/{index}/ann`
|
||||
|
||||
This endpoint is for an ANN (Approximate Nearest Neighbor) search.
|
||||
|
||||
Parameters:
|
||||
- `keyspace`: The keyspace name of the index to search.
|
||||
- `index`: The index name to search.
|
||||
|
||||
Request Body:
|
||||
```json
|
||||
{
|
||||
"embedding": [0.1, 0.2, 0.3, ...], // The vector to search for.
|
||||
"limit": 10 // The number of nearest neighbors to return.
|
||||
}
|
||||
```
|
||||
|
||||
Responses:
|
||||
|
||||
- 200 OK: Returns the nearest neighbors found.
|
||||
|
||||
Response Body (Structure of Arrays):
|
||||
```json
|
||||
{
|
||||
"distances": [0.1234, 0.5678, ...], // The distances to the nearest neighbors up to the limit provided by a request.
|
||||
"primary_keys": {
|
||||
"pk1_column_name": ["value1", "value2", ...], // The primary key values of the pk1_column_name with same size as "distances".
|
||||
"ck1_column_name": ["value1", "value2", ...], // The primary key values of the ck1_column_name with same size as "distances".
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
14
main.cc
14
main.cc
@@ -40,6 +40,7 @@
|
||||
#include "service/migration_manager.hh"
|
||||
#include "service/tablet_allocator.hh"
|
||||
#include "service/load_meter.hh"
|
||||
#include "service/vector_store_client.hh"
|
||||
#include "service/view_update_backlog_broker.hh"
|
||||
#include "service/qos/service_level_controller.hh"
|
||||
#include "streaming/stream_session.hh"
|
||||
@@ -740,6 +741,7 @@ sharded<locator::shared_token_metadata> token_metadata;
|
||||
sharded<service::mapreduce_service> mapreduce_service;
|
||||
sharded<gms::gossiper> gossiper;
|
||||
sharded<locator::snitch_ptr> snitch;
|
||||
sharded<service::vector_store_client> vector_store_client;
|
||||
|
||||
// This worker wasn't designed to be used from multiple threads.
|
||||
// If you are attempting to do that, make sure you know what you are doing.
|
||||
@@ -779,7 +781,8 @@ sharded<locator::shared_token_metadata> token_metadata;
|
||||
return seastar::async([&app, cfg, ext, &disk_space_monitor_shard0, &cm, &sstm, &db, &qp, &bm, &proxy, &mapreduce_service, &mm, &mm_notifier, &ctx, &opts, &dirs,
|
||||
&prometheus_server, &cf_cache_hitrate_calculator, &load_meter, &feature_service, &gossiper, &snitch,
|
||||
&token_metadata, &erm_factory, &snapshot_ctl, &messaging, &sst_dir_semaphore, &raft_gr, &service_memory_limiter,
|
||||
&repair, &sst_loader, &ss, &lifecycle_notifier, &stream_manager, &task_manager, &rpc_dict_training_worker] {
|
||||
&repair, &sst_loader, &ss, &lifecycle_notifier, &stream_manager, &task_manager, &rpc_dict_training_worker,
|
||||
&vector_store_client] {
|
||||
try {
|
||||
if (opts.contains("relabel-config-file") && !opts["relabel-config-file"].as<sstring>().empty()) {
|
||||
// calling update_relabel_config_from_file can cause an exception that would stop startup
|
||||
@@ -1318,6 +1321,13 @@ sharded<locator::shared_token_metadata> token_metadata;
|
||||
static sharded<cql3::cql_config> cql_config;
|
||||
cql_config.start(std::ref(*cfg)).get();
|
||||
|
||||
checkpoint(stop_signal, "starting a vector store service");
|
||||
vector_store_client.start(std::ref(*cfg)).get();
|
||||
auto stop_vector_store_client = defer_verbose_shutdown("vector store client", [&vector_store_client] {
|
||||
vector_store_client.stop().get();
|
||||
});
|
||||
vector_store_client.invoke_on_all(&service::vector_store_client::start_background_tasks).get();
|
||||
|
||||
checkpoint(stop_signal, "starting query processor");
|
||||
cql3::query_processor::memory_config qp_mcfg = {memory::stats().total_memory() / 256, memory::stats().total_memory() / 2560};
|
||||
debug::the_query_processor = &qp;
|
||||
@@ -1329,7 +1339,7 @@ sharded<locator::shared_token_metadata> token_metadata;
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(cql3::prepared_statements_cache::entry_expiry));
|
||||
auth_prep_cache_config.refresh = std::chrono::milliseconds(cfg->permissions_update_interval_in_ms());
|
||||
|
||||
qp.start(std::ref(proxy), std::move(local_data_dict), std::ref(mm_notifier), qp_mcfg, std::ref(cql_config), std::move(auth_prep_cache_config), std::ref(langman)).get();
|
||||
qp.start(std::ref(proxy), std::move(local_data_dict), std::ref(mm_notifier), std::ref(vector_store_client), qp_mcfg, std::ref(cql_config), std::move(auth_prep_cache_config), std::ref(langman)).get();
|
||||
|
||||
checkpoint(stop_signal, "starting lifecycle notifier");
|
||||
lifecycle_notifier.start().get();
|
||||
|
||||
@@ -33,7 +33,8 @@ target_sources(service
|
||||
task_manager_module.cc
|
||||
topology_coordinator.cc
|
||||
topology_mutation.cc
|
||||
topology_state_machine.cc)
|
||||
topology_state_machine.cc
|
||||
vector_store_client.cc)
|
||||
target_include_directories(service
|
||||
PUBLIC
|
||||
${CMAKE_SOURCE_DIR})
|
||||
|
||||
572
service/vector_store_client.cc
Normal file
572
service/vector_store_client.cc
Normal file
@@ -0,0 +1,572 @@
|
||||
/*
|
||||
* Copyright (C) 2025-present ScyllaDB
|
||||
*/
|
||||
|
||||
/*
|
||||
* SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
|
||||
*/
|
||||
|
||||
#include "vector_store_client.hh"
|
||||
#include "cql3/statements/select_statement.hh"
|
||||
#include "cql3/type_json.hh"
|
||||
#include "db/config.hh"
|
||||
#include "exceptions/exceptions.hh"
|
||||
#include "utils/sequential_producer.hh"
|
||||
#include "dht/i_partitioner.hh"
|
||||
#include "keys.hh"
|
||||
#include "utils/rjson.hh"
|
||||
#include "schema/schema.hh"
|
||||
#include <charconv>
|
||||
#include <exception>
|
||||
#include <fmt/ranges.h>
|
||||
#include <regex>
|
||||
#include <seastar/coroutine/as_future.hh>
|
||||
#include <seastar/coroutine/exception.hh>
|
||||
#include <seastar/http/client.hh>
|
||||
#include <seastar/http/request.hh>
|
||||
#include <seastar/net/dns.hh>
|
||||
#include <seastar/net/inet_address.hh>
|
||||
#include <seastar/net/socket_defs.hh>
|
||||
#include <seastar/util/short_streams.hh>
|
||||
|
||||
namespace {
|
||||
|
||||
using ann_error = service::vector_store_client::ann_error;
|
||||
using configuration_exception = exceptions::configuration_exception;
|
||||
using duration = lowres_clock::duration;
|
||||
using embedding = service::vector_store_client::embedding;
|
||||
using limit = service::vector_store_client::limit;
|
||||
using host_name = service::vector_store_client::host_name;
|
||||
using http_client = http::experimental::client;
|
||||
using http_path = sstring;
|
||||
using inet_address = seastar::net::inet_address;
|
||||
using json_content = sstring;
|
||||
using milliseconds = std::chrono::milliseconds;
|
||||
using operation_type = httpd::operation_type;
|
||||
using port_number = service::vector_store_client::port_number;
|
||||
using primary_key = service::vector_store_client::primary_key;
|
||||
using primary_keys = service::vector_store_client::primary_keys;
|
||||
using service_reply_format_error = service::vector_store_client::service_reply_format_error;
|
||||
using time_point = lowres_clock::time_point;
|
||||
|
||||
// Wait time before retrying after an exception occurred
|
||||
constexpr auto EXCEPTION_OCCURED_WAIT = std::chrono::seconds(5);
|
||||
|
||||
// Minimum interval between dns name refreshes
|
||||
constexpr auto DNS_REFRESH_INTERVAL = std::chrono::seconds(5);
|
||||
|
||||
/// Timeout for waiting for a new client to be available
|
||||
constexpr auto WAIT_FOR_CLIENT_TIMEOUT = std::chrono::seconds(5);
|
||||
|
||||
/// How many retries to do for HTTP requests
|
||||
constexpr auto HTTP_REQUEST_RETRIES = 3;
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
logging::logger vslogger("vector_store_client");
|
||||
|
||||
auto parse_port(std::string const& port_txt) -> std::optional<port_number> {
|
||||
auto port = port_number{};
|
||||
auto [ptr, ec] = std::from_chars(&*port_txt.begin(), &*port_txt.end(), port);
|
||||
if (*ptr != '\0' || ec != std::errc{}) {
|
||||
return std::nullopt;
|
||||
}
|
||||
return port;
|
||||
}
|
||||
|
||||
auto parse_service_uri(std::string_view uri) -> std::optional<std::tuple<host_name, port_number>> {
|
||||
constexpr auto URI_REGEX = R"(^http:\/\/([a-z0-9._-]+):([0-9]+)$)";
|
||||
auto const uri_regex = std::regex(URI_REGEX);
|
||||
auto uri_match = std::smatch{};
|
||||
auto uri_txt = std::string(uri);
|
||||
if (!std::regex_match(uri_txt, uri_match, uri_regex) || uri_match.size() != 3) {
|
||||
return {};
|
||||
}
|
||||
auto host = uri_match[1].str();
|
||||
auto port = parse_port(uri_match[2].str());
|
||||
if (!port) {
|
||||
return {};
|
||||
}
|
||||
return {{host, *port}};
|
||||
}
|
||||
|
||||
/// Wait for a timeout ar abort signal.
|
||||
auto wait_for_timeout(duration timeout, abort_source& as) -> future<bool> {
|
||||
auto result = co_await coroutine::as_future(sleep_abortable(timeout, as));
|
||||
if (result.failed()) {
|
||||
auto err = result.get_exception();
|
||||
if (as.abort_requested()) {
|
||||
co_return false;
|
||||
}
|
||||
co_await coroutine::return_exception_ptr(std::move(err));
|
||||
}
|
||||
co_return true;
|
||||
}
|
||||
|
||||
/// Wait for a condition variable to be signaled or timeout.
|
||||
auto wait_for_signal(condition_variable& cv, time_point timeout) -> future<bool> {
|
||||
auto result = co_await coroutine::as_future(cv.wait(timeout));
|
||||
if (result.failed()) {
|
||||
auto err = result.get_exception();
|
||||
if (try_catch<condition_variable_timed_out>(err) != nullptr) {
|
||||
co_return false;
|
||||
}
|
||||
co_await coroutine::return_exception_ptr(std::move(err));
|
||||
}
|
||||
co_return true;
|
||||
}
|
||||
|
||||
auto get_key_column_value(const rjson::value& item, std::size_t idx, const column_definition& column) -> std::expected<bytes, ann_error> {
|
||||
auto const& column_name = column.name_as_text();
|
||||
auto const* keys_obj = rjson::find(item, column_name);
|
||||
if (keys_obj == nullptr) {
|
||||
vslogger.error("Vector Store returned invalid JSON: missing key column '{}'", column_name);
|
||||
return std::unexpected{service_reply_format_error{}};
|
||||
}
|
||||
if (!keys_obj->IsArray()) {
|
||||
vslogger.error("Vector Store returned invalid JSON: key column '{}' is not an array", column_name);
|
||||
return std::unexpected{service_reply_format_error{}};
|
||||
}
|
||||
auto const& keys_arr = keys_obj->GetArray();
|
||||
if (keys_arr.Size() <= idx) {
|
||||
vslogger.error("Vector Store returned invalid JSON: key column '{}' array too small", column_name);
|
||||
return std::unexpected{service_reply_format_error{}};
|
||||
}
|
||||
auto const& key = keys_arr[idx];
|
||||
return from_json_object(*column.type, key);
|
||||
}
|
||||
|
||||
auto pk_from_json(rjson::value const& item, std::size_t idx, schema_ptr const& schema) -> std::expected<partition_key, ann_error> {
|
||||
std::vector<bytes> raw_pk;
|
||||
for (const column_definition& cdef : schema->partition_key_columns()) {
|
||||
auto raw_value = get_key_column_value(item, idx, cdef);
|
||||
if (!raw_value) {
|
||||
return std::unexpected{raw_value.error()};
|
||||
}
|
||||
raw_pk.emplace_back(*raw_value);
|
||||
}
|
||||
return partition_key::from_exploded(raw_pk);
|
||||
}
|
||||
|
||||
auto ck_from_json(rjson::value const& item, std::size_t idx, schema_ptr const& schema) -> std::expected<clustering_key_prefix, ann_error> {
|
||||
if (schema->clustering_key_size() == 0) {
|
||||
return clustering_key_prefix::make_empty();
|
||||
}
|
||||
|
||||
std::vector<bytes> raw_ck;
|
||||
for (const column_definition& cdef : schema->clustering_key_columns()) {
|
||||
auto raw_value = get_key_column_value(item, idx, cdef);
|
||||
if (!raw_value) {
|
||||
return std::unexpected{raw_value.error()};
|
||||
}
|
||||
raw_ck.emplace_back(*raw_value);
|
||||
}
|
||||
|
||||
return clustering_key_prefix::from_exploded(raw_ck);
|
||||
}
|
||||
|
||||
auto write_ann_json(embedding embedding, limit limit) -> json_content {
|
||||
return seastar::format(R"({{"embedding":[{}],"limit":{}}})", fmt::join(embedding, ","), limit);
|
||||
}
|
||||
|
||||
auto read_ann_json(rjson::value const& json, schema_ptr const& schema) -> std::expected<primary_keys, ann_error> {
|
||||
if (!json.HasMember("primary_keys")) {
|
||||
vslogger.error("Vector Store returned invalid JSON: missing 'primary_keys'");
|
||||
return std::unexpected{service_reply_format_error{}};
|
||||
}
|
||||
auto const& keys_json = json["primary_keys"];
|
||||
if (!keys_json.IsObject()) {
|
||||
vslogger.error("Vector Store returned invalid JSON: 'primary_keys' is not an object");
|
||||
return std::unexpected{service_reply_format_error{}};
|
||||
}
|
||||
|
||||
if (!json.HasMember("distances")) {
|
||||
vslogger.error("Vector Store returned invalid JSON: missing 'distances'");
|
||||
return std::unexpected{service_reply_format_error{}};
|
||||
}
|
||||
auto const& distances_json = json["distances"];
|
||||
if (!distances_json.IsArray()) {
|
||||
vslogger.error("Vector Store returned invalid JSON: 'distances' is not an array");
|
||||
return std::unexpected{service_reply_format_error{}};
|
||||
}
|
||||
auto const& distances_arr = json["distances"].GetArray();
|
||||
|
||||
auto size = distances_arr.Size();
|
||||
auto keys = primary_keys{};
|
||||
for (auto idx = 0U; idx < size; ++idx) {
|
||||
auto pk = pk_from_json(keys_json, idx, schema);
|
||||
if (!pk) {
|
||||
return std::unexpected{pk.error()};
|
||||
}
|
||||
auto ck = ck_from_json(keys_json, idx, schema);
|
||||
if (!ck) {
|
||||
return std::unexpected{ck.error()};
|
||||
}
|
||||
keys.push_back(primary_key{dht::decorate_key(*schema, *pk), *ck});
|
||||
}
|
||||
return std::move(keys);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace service {
|
||||
|
||||
struct vector_store_client::impl {
|
||||
lw_shared_ptr<http_client> current_client;
|
||||
std::vector<lw_shared_ptr<http_client>> old_clients;
|
||||
host_name host;
|
||||
port_number port{};
|
||||
inet_address addr;
|
||||
time_point last_dns_refresh;
|
||||
gate tasks_gate;
|
||||
condition_variable refresh_cv;
|
||||
condition_variable refresh_client_cv;
|
||||
abort_source abort_refresh;
|
||||
milliseconds dns_refresh_interval = DNS_REFRESH_INTERVAL;
|
||||
milliseconds wait_for_client_timeout = WAIT_FOR_CLIENT_TIMEOUT;
|
||||
unsigned http_request_retries = HTTP_REQUEST_RETRIES;
|
||||
std::function<future<std::optional<inet_address>>(sstring const&)> dns_resolver;
|
||||
sequential_producer<lw_shared_ptr<http_client>> client_producer;
|
||||
|
||||
impl(host_name host_, port_number port_)
|
||||
: host(std::move(host_))
|
||||
, port(port_)
|
||||
, dns_resolver([](auto const& host) -> future<std::optional<inet_address>> {
|
||||
auto addr = co_await coroutine::as_future(net::dns::resolve_name(host));
|
||||
if (addr.failed()) {
|
||||
auto err = addr.get_exception();
|
||||
if (try_catch<std::system_error>(err) != nullptr) {
|
||||
co_return std::nullopt;
|
||||
}
|
||||
co_await coroutine::return_exception_ptr(std::move(err));
|
||||
}
|
||||
co_return co_await std::move(addr);
|
||||
})
|
||||
, client_producer([&]() -> future<lw_shared_ptr<http_client>> {
|
||||
trigger_dns_refresh();
|
||||
co_await wait_for_signal(refresh_client_cv, lowres_clock::now() + wait_for_client_timeout);
|
||||
co_return current_client;
|
||||
}) {
|
||||
}
|
||||
|
||||
/// Refresh the http client with a new address resolved from the DNS name.
|
||||
/// If the DNS resolution fails, the current client is set to nullptr.
|
||||
/// If the address is the same as the current one, do nothing.
|
||||
/// Old clients are saved for later cleanup in a specific task.
|
||||
auto refresh_addr() -> future<> {
|
||||
auto new_addr = co_await dns_resolver(host);
|
||||
if (!new_addr) {
|
||||
current_client = nullptr;
|
||||
co_return;
|
||||
}
|
||||
|
||||
// Check if the new address is the same as the current one
|
||||
if (current_client && *new_addr == addr) {
|
||||
co_return;
|
||||
}
|
||||
|
||||
addr = *new_addr;
|
||||
old_clients.emplace_back(current_client);
|
||||
current_client = make_lw_shared<http_client>(socket_address(addr, port));
|
||||
}
|
||||
|
||||
/// A task for refreshing the vector store http client.
|
||||
auto refresh_addr_task() -> future<> {
|
||||
for (;;) {
|
||||
auto exception_occured = false;
|
||||
try {
|
||||
if (abort_refresh.abort_requested()) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Do not refresh the service address too often
|
||||
auto now = lowres_clock::now();
|
||||
auto current_duration = now - last_dns_refresh;
|
||||
if (current_duration > dns_refresh_interval) {
|
||||
last_dns_refresh = now;
|
||||
co_await refresh_addr();
|
||||
} else {
|
||||
// Wait till the end of the refreshing interval
|
||||
if (co_await wait_for_timeout(dns_refresh_interval - current_duration, abort_refresh)) {
|
||||
continue;
|
||||
}
|
||||
// If the wait was aborted, we stop refreshing
|
||||
break;
|
||||
}
|
||||
|
||||
if (abort_refresh.abort_requested()) {
|
||||
break;
|
||||
}
|
||||
|
||||
// new client is available
|
||||
refresh_client_cv.broadcast();
|
||||
|
||||
co_await cleanup_old_clients();
|
||||
|
||||
co_await refresh_cv.when();
|
||||
} catch (const std::exception& e) {
|
||||
vslogger.error("Vector Store Client refresh task failed: {}", e.what());
|
||||
exception_occured = true;
|
||||
} catch (...) {
|
||||
vslogger.error("Vector Store Client refresh task failed with unknown exception");
|
||||
exception_occured = true;
|
||||
}
|
||||
if (exception_occured) {
|
||||
// If an exception occurred, we wait for the next signal to refresh the address
|
||||
co_await wait_for_timeout(EXCEPTION_OCCURED_WAIT, abort_refresh);
|
||||
}
|
||||
}
|
||||
|
||||
co_await cleanup_old_clients();
|
||||
co_await cleanup_current_client();
|
||||
}
|
||||
|
||||
/// Request a DNS refresh in the specific task.
|
||||
void trigger_dns_refresh() {
|
||||
refresh_cv.signal();
|
||||
}
|
||||
|
||||
/// Cleanup current client
|
||||
auto cleanup_current_client() -> future<> {
|
||||
if (current_client) {
|
||||
co_await current_client->close();
|
||||
}
|
||||
current_client = nullptr;
|
||||
}
|
||||
|
||||
/// Cleanup old clients that are no longer used.
|
||||
auto cleanup_old_clients() -> future<> {
|
||||
// iterate over old clients and close them. There is a co_await in the loop
|
||||
// so we need to use [] accessor and copying clients to avoid dangling references of iterators.
|
||||
// NOLINTNEXTLINE(modernize-loop-convert)
|
||||
for (auto it = 0U; it < old_clients.size(); ++it) {
|
||||
auto& client = old_clients[it];
|
||||
if (client && client.owned()) {
|
||||
auto client_cloned = client;
|
||||
co_await client_cloned->close();
|
||||
client_cloned = nullptr;
|
||||
}
|
||||
}
|
||||
std::erase_if(old_clients, [](auto const& client) {
|
||||
return !client;
|
||||
});
|
||||
}
|
||||
|
||||
struct get_client_response {
|
||||
lw_shared_ptr<http_client> client; ///< The http client.
|
||||
host_name host; ///< The host name for the vector-store service.
|
||||
};
|
||||
|
||||
using get_client_error = std::variant<aborted, addr_unavailable>;
|
||||
|
||||
/// Get the current http client or wait for a new one to be available.
|
||||
auto get_client(abort_source& as) -> future<std::expected<get_client_response, get_client_error>> {
|
||||
if (current_client) {
|
||||
co_return get_client_response{.client = current_client, .host = host};
|
||||
}
|
||||
|
||||
auto current_client = co_await coroutine::as_future(client_producer(as));
|
||||
|
||||
if (current_client.failed()) {
|
||||
auto err = current_client.get_exception();
|
||||
if (as.abort_requested()) {
|
||||
co_return std::unexpected{aborted{}};
|
||||
}
|
||||
co_await coroutine::return_exception_ptr(std::move(err));
|
||||
}
|
||||
auto client = co_await std::move(current_client);
|
||||
if (!client) {
|
||||
co_return std::unexpected{addr_unavailable{}};
|
||||
}
|
||||
co_return get_client_response{.client = client, .host = host};
|
||||
}
|
||||
|
||||
struct make_request_response {
|
||||
http::reply::status_type status; ///< The HTTP status of the response.
|
||||
std::vector<temporary_buffer<char>> content; ///< The content of the response.
|
||||
};
|
||||
|
||||
using make_request_error = std::variant<aborted, addr_unavailable, service_unavailable>;
|
||||
|
||||
auto make_request(operation_type method, http_path path, std::optional<json_content> content, abort_source& as)
|
||||
-> future<std::expected<make_request_response, make_request_error>> {
|
||||
auto resp = make_request_response{.status = http::reply::status_type::ok, .content = std::vector<temporary_buffer<char>>()};
|
||||
|
||||
for (auto retries = 0; retries < HTTP_REQUEST_RETRIES; ++retries) {
|
||||
auto client_host = co_await get_client(as);
|
||||
if (!client_host) {
|
||||
co_return std::unexpected{std::visit(
|
||||
[](auto&& err) {
|
||||
return make_request_error{err};
|
||||
},
|
||||
client_host.error())};
|
||||
}
|
||||
auto [client, host] = *std::move(client_host);
|
||||
|
||||
auto req = http::request::make(method, host, path);
|
||||
if (content) {
|
||||
req.write_body("json", *content);
|
||||
}
|
||||
|
||||
auto result = co_await coroutine::as_future(client->make_request(
|
||||
std::move(req),
|
||||
[&resp](http::reply const& reply, input_stream<char> body) -> future<> {
|
||||
resp.status = reply._status;
|
||||
resp.content = co_await util::read_entire_stream(body);
|
||||
},
|
||||
std::nullopt, &as));
|
||||
if (result.failed()) {
|
||||
auto err = result.get_exception();
|
||||
if (as.abort_requested()) {
|
||||
co_return std::unexpected{aborted{}};
|
||||
}
|
||||
if (try_catch<std::system_error>(err) == nullptr) {
|
||||
co_await coroutine::return_exception_ptr(std::move(err));
|
||||
}
|
||||
// std::system_error means that the server is unavailable, so we retry
|
||||
} else {
|
||||
co_return resp;
|
||||
}
|
||||
|
||||
trigger_dns_refresh();
|
||||
}
|
||||
|
||||
co_return std::unexpected{service_unavailable{}};
|
||||
}
|
||||
};
|
||||
|
||||
vector_store_client::vector_store_client(config const& cfg) {
|
||||
auto config_uri = cfg.vector_store_uri();
|
||||
if (config_uri.empty()) {
|
||||
vslogger.info("Vector Store service URI is not configured.");
|
||||
return;
|
||||
}
|
||||
|
||||
auto parsed_uri = parse_service_uri(config_uri);
|
||||
if (!parsed_uri) {
|
||||
throw configuration_exception(format("Invalid Vector Store service URI: {}", config_uri));
|
||||
}
|
||||
|
||||
auto [host, port] = *parsed_uri;
|
||||
_impl = std::make_unique<impl>(std::move(host), port);
|
||||
vslogger.info("Vector Store service uri = {}:{}.", _impl->host, _impl->port);
|
||||
}
|
||||
|
||||
vector_store_client::~vector_store_client() = default;
|
||||
|
||||
void vector_store_client::start_background_tasks() {
|
||||
if (is_disabled()) {
|
||||
return;
|
||||
}
|
||||
|
||||
/// start the background task to refresh the service address
|
||||
(void)try_with_gate(_impl->tasks_gate, [this] {
|
||||
return _impl->refresh_addr_task();
|
||||
}).handle_exception([](std::exception_ptr eptr) {
|
||||
on_internal_error_noexcept(vslogger, format("The Vector Store Client refresh task failed: {}", eptr));
|
||||
});
|
||||
}
|
||||
|
||||
auto vector_store_client::stop() -> future<> {
|
||||
if (is_disabled()) {
|
||||
co_return;
|
||||
}
|
||||
|
||||
_impl->abort_refresh.request_abort();
|
||||
_impl->refresh_cv.signal();
|
||||
co_await _impl->tasks_gate.close();
|
||||
}
|
||||
|
||||
auto vector_store_client::host() const -> std::expected<host_name, disabled> {
|
||||
if (is_disabled()) {
|
||||
return std::unexpected{disabled{}};
|
||||
}
|
||||
return {_impl->host};
|
||||
}
|
||||
|
||||
auto vector_store_client::port() const -> std::expected<port_number, disabled> {
|
||||
if (is_disabled()) {
|
||||
return std::unexpected{disabled{}};
|
||||
}
|
||||
return {_impl->port};
|
||||
}
|
||||
|
||||
auto vector_store_client::ann(keyspace_name keyspace, index_name name, schema_ptr schema, embedding embedding, limit limit, abort_source& as)
|
||||
-> future<std::expected<primary_keys, ann_error>> {
|
||||
if (is_disabled()) {
|
||||
vslogger.error("Disabled Vector Store while calling ann");
|
||||
co_return std::unexpected{disabled{}};
|
||||
}
|
||||
|
||||
auto path = format("/api/v1/indexes/{}/{}/ann", keyspace, name);
|
||||
auto content = write_ann_json(std::move(embedding), limit);
|
||||
|
||||
auto resp = co_await _impl->make_request(operation_type::POST, std::move(path), std::move(content), as);
|
||||
if (!resp) {
|
||||
co_return std::unexpected{std::visit(
|
||||
[](auto&& err) {
|
||||
return ann_error{err};
|
||||
},
|
||||
resp.error())};
|
||||
}
|
||||
|
||||
if (resp->status != status_type::ok) {
|
||||
vslogger.error("Vector Store returned error: HTTP status {}: {}", resp->status, resp->content);
|
||||
co_return std::unexpected{service_error{resp->status}};
|
||||
}
|
||||
|
||||
try {
|
||||
co_return read_ann_json(rjson::parse(std::move(resp->content)), schema);
|
||||
} catch (const rjson::error& e) {
|
||||
vslogger.error("Vector Store returned invalid JSON: {}", e.what());
|
||||
co_return std::unexpected{service_reply_format_error{}};
|
||||
}
|
||||
}
|
||||
|
||||
void vector_store_client_tester::set_dns_refresh_interval(vector_store_client& vsc, std::chrono::milliseconds interval) {
|
||||
if (vsc.is_disabled()) {
|
||||
on_internal_error(vslogger, "Cannot set dns_refresh_interval on a disabled vector store client");
|
||||
}
|
||||
vsc._impl->dns_refresh_interval = interval;
|
||||
}
|
||||
|
||||
void vector_store_client_tester::set_wait_for_client_timeout(vector_store_client& vsc, std::chrono::milliseconds timeout) {
|
||||
if (vsc.is_disabled()) {
|
||||
on_internal_error(vslogger, "Cannot set wait_for_client_timeout on a disabled vector store client");
|
||||
}
|
||||
vsc._impl->wait_for_client_timeout = timeout;
|
||||
}
|
||||
|
||||
void vector_store_client_tester::set_http_request_retries(vector_store_client& vsc, unsigned retries) {
|
||||
if (vsc.is_disabled()) {
|
||||
on_internal_error(vslogger, "Cannot set http_request_retries on a disabled vector store client");
|
||||
}
|
||||
vsc._impl->http_request_retries = retries;
|
||||
}
|
||||
|
||||
void vector_store_client_tester::set_dns_resolver(vector_store_client& vsc, std::function<future<std::optional<inet_address>>(sstring const&)> resolver) {
|
||||
if (vsc.is_disabled()) {
|
||||
on_internal_error(vslogger, "Cannot set dns_resolver on a disabled vector store client");
|
||||
}
|
||||
vsc._impl->dns_resolver = std::move(resolver);
|
||||
}
|
||||
|
||||
void vector_store_client_tester::trigger_dns_resolver(vector_store_client& vsc) {
|
||||
if (vsc.is_disabled()) {
|
||||
on_internal_error(vslogger, "Cannot trigger a dns resolver on a disabled vector store client");
|
||||
}
|
||||
vsc._impl->trigger_dns_refresh();
|
||||
}
|
||||
|
||||
auto vector_store_client_tester::resolve_hostname(vector_store_client& vsc, abort_source& as) -> future<std::optional<inet_address>> {
|
||||
if (vsc.is_disabled()) {
|
||||
on_internal_error(vslogger, "Cannot check hostname resolving on a disabled vector store client");
|
||||
}
|
||||
auto client_host = co_await vsc._impl->get_client(as);
|
||||
if (!client_host) {
|
||||
co_return std::nullopt;
|
||||
}
|
||||
co_return vsc._impl->addr;
|
||||
}
|
||||
|
||||
} // namespace service
|
||||
|
||||
112
service/vector_store_client.hh
Normal file
112
service/vector_store_client.hh
Normal file
@@ -0,0 +1,112 @@
|
||||
/*
|
||||
* Copyright (C) 2025-present ScyllaDB
|
||||
*/
|
||||
|
||||
/*
|
||||
* SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "seastarx.hh"
|
||||
#include <seastar/core/shared_future.hh>
|
||||
#include <seastar/core/shared_ptr.hh>
|
||||
#include <seastar/http/reply.hh>
|
||||
#include <expected>
|
||||
|
||||
class schema;
|
||||
|
||||
namespace cql3::statements {
|
||||
class primary_key;
|
||||
}
|
||||
|
||||
namespace db {
|
||||
class config;
|
||||
}
|
||||
|
||||
namespace seastar::net {
|
||||
class inet_address;
|
||||
}
|
||||
|
||||
namespace service {
|
||||
|
||||
/// A client with the vector-store service.
|
||||
class vector_store_client final {
|
||||
struct impl;
|
||||
std::unique_ptr<impl> _impl;
|
||||
|
||||
public:
|
||||
using config = db::config;
|
||||
using embedding = std::vector<float>;
|
||||
using host_name = sstring;
|
||||
using index_name = sstring;
|
||||
using keyspace_name = sstring;
|
||||
using limit = std::size_t;
|
||||
using port_number = std::uint16_t;
|
||||
using primary_key = cql3::statements::primary_key;
|
||||
using primary_keys = std::vector<primary_key>;
|
||||
using schema_ptr = lw_shared_ptr<schema const>;
|
||||
using status_type = http::reply::status_type;
|
||||
|
||||
/// The vector_store_client service is disabled.
|
||||
struct disabled {};
|
||||
|
||||
/// The operation was aborted.
|
||||
struct aborted {};
|
||||
|
||||
/// The vector-store addr is unavailable (not possible to get an addr from the dns service).
|
||||
struct addr_unavailable {};
|
||||
|
||||
/// The vector-store service is unavailable.
|
||||
struct service_unavailable {};
|
||||
|
||||
/// The error from the vector-store service.
|
||||
struct service_error {
|
||||
status_type status; ///< The HTTP status code from the vector-store service.
|
||||
};
|
||||
|
||||
/// An unsupported reply format from the vector-store service.
|
||||
struct service_reply_format_error {};
|
||||
|
||||
using ann_error = std::variant<disabled, aborted, addr_unavailable, service_unavailable, service_error, service_reply_format_error>;
|
||||
|
||||
explicit vector_store_client(config const& cfg);
|
||||
~vector_store_client();
|
||||
|
||||
/// Start background tasks.
|
||||
void start_background_tasks();
|
||||
|
||||
/// Stop the service.
|
||||
auto stop() -> future<>;
|
||||
|
||||
/// Check if the vector_store_client is disabled.
|
||||
auto is_disabled() const {
|
||||
return !bool{_impl};
|
||||
}
|
||||
|
||||
/// Get the current host name.
|
||||
[[nodiscard]] auto host() const -> std::expected<host_name, disabled>;
|
||||
|
||||
/// Get the current port number.
|
||||
[[nodiscard]] auto port() const -> std::expected<port_number, disabled>;
|
||||
|
||||
/// Request the vector store service for the primary keys of the nearest neighbors
|
||||
auto ann(keyspace_name keyspace, index_name name, schema_ptr schema, embedding embedding, limit limit, abort_source& as)
|
||||
-> future<std::expected<primary_keys, ann_error>>;
|
||||
|
||||
private:
|
||||
friend struct vector_store_client_tester;
|
||||
};
|
||||
|
||||
/// A tester for the vector_store_client, used for testing purposes.
|
||||
struct vector_store_client_tester {
|
||||
static void set_dns_refresh_interval(vector_store_client& vsc, std::chrono::milliseconds interval);
|
||||
static void set_wait_for_client_timeout(vector_store_client& vsc, std::chrono::milliseconds timeout);
|
||||
static void set_http_request_retries(vector_store_client& vsc, unsigned retries);
|
||||
static void set_dns_resolver(vector_store_client& vsc, std::function<future<std::optional<net::inet_address>>(sstring const&)> resolver);
|
||||
static void trigger_dns_resolver(vector_store_client& vsc);
|
||||
static auto resolve_hostname(vector_store_client& vsc, abort_source& as) -> future<std::optional<net::inet_address>>;
|
||||
};
|
||||
|
||||
} // namespace service
|
||||
|
||||
@@ -286,6 +286,8 @@ add_scylla_test(wrapping_interval_test
|
||||
KIND BOOST)
|
||||
add_scylla_test(address_map_test
|
||||
KIND SEASTAR)
|
||||
add_scylla_test(vector_store_client_test
|
||||
KIND SEASTAR)
|
||||
|
||||
add_scylla_test(combined_tests
|
||||
KIND SEASTAR
|
||||
|
||||
540
test/boost/vector_store_client_test.cc
Normal file
540
test/boost/vector_store_client_test.cc
Normal file
@@ -0,0 +1,540 @@
|
||||
/*
|
||||
* Copyright (C) 2025-present ScyllaDB
|
||||
*/
|
||||
|
||||
/*
|
||||
* SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
|
||||
*/
|
||||
|
||||
#include "service/vector_store_client.hh"
|
||||
#include "db/config.hh"
|
||||
#include "exceptions/exceptions.hh"
|
||||
#include "cql3/statements/select_statement.hh"
|
||||
#include "test/lib/cql_test_env.hh"
|
||||
#include "test/lib/log.hh"
|
||||
#include <memory>
|
||||
#include <seastar/core/shared_ptr.hh>
|
||||
#include <seastar/net/api.hh>
|
||||
#include <seastar/http/function_handlers.hh>
|
||||
#include <seastar/http/httpd.hh>
|
||||
#include <seastar/json/json_elements.hh>
|
||||
#include <seastar/net/dns.hh>
|
||||
#include <seastar/net/inet_address.hh>
|
||||
#include <seastar/net/socket_defs.hh>
|
||||
#include <seastar/testing/test_case.hh>
|
||||
#include <seastar/testing/thread_test_case.hh>
|
||||
#include <seastar/util/short_streams.hh>
|
||||
#include <variant>
|
||||
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace seastar;
|
||||
|
||||
using vector_store_client = service::vector_store_client;
|
||||
using vector_store_client_tester = service::vector_store_client_tester;
|
||||
using config = vector_store_client::config;
|
||||
using configuration_exception = exceptions::configuration_exception;
|
||||
using inet_address = seastar::net::inet_address;
|
||||
using function_handler = httpd::function_handler;
|
||||
using http_server = httpd::http_server;
|
||||
using http_server_tester = httpd::http_server_tester;
|
||||
using milliseconds = std::chrono::milliseconds;
|
||||
using operation_type = httpd::operation_type;
|
||||
using port_number = vector_store_client::port_number;
|
||||
using reply = http::reply;
|
||||
using request = http::request;
|
||||
using routes = httpd::routes;
|
||||
using status_type = http::reply::status_type;
|
||||
using url = httpd::url;
|
||||
|
||||
constexpr auto const* LOCALHOST = "127.0.0.1";
|
||||
|
||||
/// Generate an ephemeral port number for listening on localhost.
|
||||
/// After closing this socket, the port should be not listened on for a while.
|
||||
/// This is not guaranteed to be a robust solution, but it should work for most tests.
|
||||
auto generate_unavailable_localhost_port() -> port_number {
|
||||
auto inaddr = net::inet_address(LOCALHOST);
|
||||
auto server = listen(socket_address(inaddr, 0));
|
||||
auto port = server.local_address().port();
|
||||
server.abort_accept();
|
||||
return port;
|
||||
}
|
||||
|
||||
auto listen_on_ephemeral_port(std::unique_ptr<http_server> server) -> future<std::tuple<std::unique_ptr<http_server>, socket_address>> {
|
||||
auto inaddr = net::inet_address(LOCALHOST);
|
||||
auto const addr = socket_address(inaddr, 0);
|
||||
co_await server->listen(addr);
|
||||
auto const& listeners = http_server_tester::listeners(*server);
|
||||
BOOST_CHECK_EQUAL(listeners.size(), 1);
|
||||
co_return std::make_tuple(std::move(server), listeners[0].local_address().port());
|
||||
}
|
||||
|
||||
auto new_http_server(std::function<void(routes& r)> set_routes) -> future<std::tuple<std::unique_ptr<http_server>, socket_address>> {
|
||||
auto server = std::make_unique<http_server>("test_vector_store_client");
|
||||
set_routes(server->_routes);
|
||||
server->set_content_streaming(true);
|
||||
co_return co_await listen_on_ephemeral_port(std::move(server));
|
||||
}
|
||||
|
||||
auto repeat_until(milliseconds timeout, std::function<future<bool>()> func) -> future<bool> {
|
||||
auto begin = lowres_clock::now();
|
||||
while (!co_await func()) {
|
||||
if (lowres_clock::now() - begin > timeout) {
|
||||
co_return false;
|
||||
}
|
||||
}
|
||||
co_return true;
|
||||
}
|
||||
|
||||
auto print_addr(const inet_address& addr) -> sstring {
|
||||
return format("{}", addr);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
BOOST_AUTO_TEST_CASE(vector_store_client_test_ctor) {
|
||||
{
|
||||
auto cfg = config();
|
||||
auto vs = vector_store_client{cfg};
|
||||
BOOST_CHECK(vs.is_disabled());
|
||||
BOOST_CHECK(!vs.host());
|
||||
BOOST_CHECK(!vs.port());
|
||||
}
|
||||
{
|
||||
auto cfg = config();
|
||||
cfg.vector_store_uri.set("http://good.authority.com:6080");
|
||||
auto vs = vector_store_client{cfg};
|
||||
BOOST_CHECK(!vs.is_disabled());
|
||||
BOOST_CHECK_EQUAL(*vs.host(), "good.authority.com");
|
||||
BOOST_CHECK_EQUAL(*vs.port(), 6080);
|
||||
}
|
||||
{
|
||||
auto cfg = config();
|
||||
cfg.vector_store_uri.set("http://bad,authority.com:6080");
|
||||
BOOST_CHECK_THROW(vector_store_client{cfg}, configuration_exception);
|
||||
cfg.vector_store_uri.set("bad-schema://authority.com:6080");
|
||||
BOOST_CHECK_THROW(vector_store_client{cfg}, configuration_exception);
|
||||
cfg.vector_store_uri.set("http://bad.port.com:a6080");
|
||||
BOOST_CHECK_THROW(vector_store_client{cfg}, configuration_exception);
|
||||
cfg.vector_store_uri.set("http://bad.port.com:60806080");
|
||||
BOOST_CHECK_THROW(vector_store_client{cfg}, configuration_exception);
|
||||
cfg.vector_store_uri.set("http://bad.format.com:60:80");
|
||||
BOOST_CHECK_THROW(vector_store_client{cfg}, configuration_exception);
|
||||
cfg.vector_store_uri.set("http://authority.com:6080/bad/path");
|
||||
BOOST_CHECK_THROW(vector_store_client{cfg}, configuration_exception);
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolving of the hostname is started in start_background_tasks()
|
||||
SEASTAR_TEST_CASE(vector_store_client_test_dns_started) {
|
||||
auto cfg = config();
|
||||
cfg.vector_store_uri.set("http://good.authority.here:6080");
|
||||
|
||||
auto vs = vector_store_client{cfg};
|
||||
BOOST_CHECK(!vs.is_disabled());
|
||||
|
||||
vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(2000));
|
||||
vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(100));
|
||||
vector_store_client_tester::set_dns_resolver(vs, [](auto const& host) -> future<std::optional<inet_address>> {
|
||||
BOOST_CHECK_EQUAL(host, "good.authority.here");
|
||||
co_return inet_address("127.0.0.1");
|
||||
});
|
||||
|
||||
vs.start_background_tasks();
|
||||
|
||||
auto as = abort_source();
|
||||
auto addr = co_await vector_store_client_tester::resolve_hostname(vs, as);
|
||||
BOOST_REQUIRE(addr);
|
||||
BOOST_CHECK_EQUAL(print_addr(*addr), "127.0.0.1");
|
||||
|
||||
co_await vs.stop();
|
||||
}
|
||||
|
||||
/// Unable to resolve the hostname
|
||||
SEASTAR_TEST_CASE(vector_store_client_test_dns_resolve_failure) {
|
||||
auto cfg = config();
|
||||
cfg.vector_store_uri.set("http://good.authority.here:6080");
|
||||
|
||||
|
||||
auto vs = vector_store_client{cfg};
|
||||
BOOST_CHECK(!vs.is_disabled());
|
||||
|
||||
vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(2000));
|
||||
vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(100));
|
||||
vector_store_client_tester::set_dns_resolver(vs, [](auto const& host) -> future<std::optional<inet_address>> {
|
||||
BOOST_CHECK_EQUAL(host, "good.authority.here");
|
||||
co_return std::nullopt;
|
||||
});
|
||||
|
||||
vs.start_background_tasks();
|
||||
|
||||
auto as = abort_source();
|
||||
BOOST_CHECK(!co_await vector_store_client_tester::resolve_hostname(vs, as));
|
||||
|
||||
co_await vs.stop();
|
||||
}
|
||||
|
||||
/// Resolving of the hostname is repeated after errors
|
||||
SEASTAR_TEST_CASE(vector_store_client_test_dns_resolving_repeated) {
|
||||
auto cfg = config();
|
||||
cfg.vector_store_uri.set("http://good.authority.here:6080");
|
||||
auto vs = vector_store_client{cfg};
|
||||
BOOST_CHECK(!vs.is_disabled());
|
||||
|
||||
vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(10));
|
||||
vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(20));
|
||||
auto count = 0;
|
||||
vector_store_client_tester::set_dns_resolver(vs, [&count](auto const& host) -> future<std::optional<inet_address>> {
|
||||
BOOST_CHECK_EQUAL(host, "good.authority.here");
|
||||
count++;
|
||||
if (count % 3 != 0) {
|
||||
co_return std::nullopt;
|
||||
}
|
||||
co_return inet_address(format("127.0.0.{}", count));
|
||||
});
|
||||
|
||||
vs.start_background_tasks();
|
||||
|
||||
auto as = abort_source();
|
||||
BOOST_CHECK(co_await repeat_until(std::chrono::milliseconds(1000), [&vs, &as]() -> future<bool> {
|
||||
co_return co_await vector_store_client_tester::resolve_hostname(vs, as);
|
||||
}));
|
||||
BOOST_CHECK_EQUAL(count, 3);
|
||||
auto addr = co_await vector_store_client_tester::resolve_hostname(vs, as);
|
||||
BOOST_REQUIRE(addr);
|
||||
BOOST_CHECK_EQUAL(print_addr(*addr), "127.0.0.3");
|
||||
|
||||
vector_store_client_tester::trigger_dns_resolver(vs);
|
||||
|
||||
BOOST_CHECK(co_await repeat_until(std::chrono::milliseconds(1000), [&vs, &as]() -> future<bool> {
|
||||
co_return !co_await vector_store_client_tester::resolve_hostname(vs, as);
|
||||
}));
|
||||
|
||||
BOOST_CHECK(co_await repeat_until(std::chrono::milliseconds(1000), [&vs, &as]() -> future<bool> {
|
||||
co_return co_await vector_store_client_tester::resolve_hostname(vs, as);
|
||||
}));
|
||||
BOOST_CHECK_EQUAL(count, 6);
|
||||
addr = co_await vector_store_client_tester::resolve_hostname(vs, as);
|
||||
BOOST_REQUIRE(addr);
|
||||
BOOST_CHECK_EQUAL(print_addr(*addr), "127.0.0.6");
|
||||
|
||||
co_await vs.stop();
|
||||
}
|
||||
|
||||
/// Minimal interval between DNS refreshes is respected
|
||||
SEASTAR_TEST_CASE(vector_store_client_test_dns_refresh_respects_interval) {
|
||||
auto cfg = config();
|
||||
cfg.vector_store_uri.set("http://good.authority.here:6080");
|
||||
auto vs = vector_store_client{cfg};
|
||||
BOOST_CHECK(!vs.is_disabled());
|
||||
|
||||
vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(10));
|
||||
vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(100));
|
||||
auto count = 0;
|
||||
vector_store_client_tester::set_dns_resolver(vs, [&count](auto const& host) -> future<std::optional<inet_address>> {
|
||||
BOOST_CHECK_EQUAL(host, "good.authority.here");
|
||||
count++;
|
||||
co_return inet_address("127.0.0.1");
|
||||
});
|
||||
|
||||
vs.start_background_tasks();
|
||||
co_await sleep(std::chrono::milliseconds(20)); // wait for the first DNS refresh
|
||||
|
||||
auto as = abort_source();
|
||||
auto addr = co_await vector_store_client_tester::resolve_hostname(vs, as);
|
||||
BOOST_REQUIRE(addr);
|
||||
BOOST_CHECK_EQUAL(print_addr(*addr), "127.0.0.1");
|
||||
BOOST_CHECK_EQUAL(count, 1);
|
||||
count = 0;
|
||||
vector_store_client_tester::trigger_dns_resolver(vs);
|
||||
vector_store_client_tester::trigger_dns_resolver(vs);
|
||||
vector_store_client_tester::trigger_dns_resolver(vs);
|
||||
vector_store_client_tester::trigger_dns_resolver(vs);
|
||||
vector_store_client_tester::trigger_dns_resolver(vs);
|
||||
co_await sleep(std::chrono::milliseconds(100)); // wait for the next DNS refresh
|
||||
|
||||
addr = co_await vector_store_client_tester::resolve_hostname(vs, as);
|
||||
BOOST_REQUIRE(addr);
|
||||
BOOST_CHECK_EQUAL(print_addr(*addr), "127.0.0.1");
|
||||
BOOST_CHECK_GE(count, 1);
|
||||
BOOST_CHECK_LE(count, 2);
|
||||
|
||||
co_await vs.stop();
|
||||
}
|
||||
|
||||
/// DNS refresh could be aborted
|
||||
SEASTAR_TEST_CASE(vector_store_client_test_dns_refresh_aborted) {
|
||||
auto cfg = config();
|
||||
cfg.vector_store_uri.set("http://good.authority.here:6080");
|
||||
auto vs = vector_store_client{cfg};
|
||||
BOOST_CHECK(!vs.is_disabled());
|
||||
|
||||
vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(10));
|
||||
vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(100));
|
||||
vector_store_client_tester::set_dns_resolver(vs, [](auto const& host) -> future<std::optional<inet_address>> {
|
||||
BOOST_CHECK_EQUAL(host, "good.authority.here");
|
||||
co_await sleep(std::chrono::milliseconds(100));
|
||||
co_return inet_address("127.0.0.1");
|
||||
});
|
||||
|
||||
vs.start_background_tasks();
|
||||
|
||||
auto as = abort_source();
|
||||
auto timeout = timer([&as]() {
|
||||
as.request_abort();
|
||||
});
|
||||
timeout.arm(std::chrono::milliseconds(10));
|
||||
auto addr = co_await vector_store_client_tester::resolve_hostname(vs, as);
|
||||
BOOST_CHECK(!addr);
|
||||
|
||||
co_await vs.stop();
|
||||
}
|
||||
|
||||
SEASTAR_TEST_CASE(vector_store_client_ann_test_disabled) {
|
||||
co_await do_with_cql_env([](cql_test_env& env) -> future<> {
|
||||
co_await env.execute_cql(R"(
|
||||
create table ks.vs (
|
||||
pk1 tinyint, pk2 tinyint,
|
||||
ck1 tinyint, ck2 tinyint,
|
||||
embedding vector<float, 3>,
|
||||
primary key ((pk1, pk2), ck1, ck2))
|
||||
)");
|
||||
|
||||
auto schema = env.local_db().find_schema("ks", "vs");
|
||||
auto& vs = env.local_qp().vector_store_client();
|
||||
|
||||
auto as = abort_source();
|
||||
auto keys = co_await vs.ann("ks", "idx", schema, std::vector<float>{0.1, 0.2, 0.3}, 2, as);
|
||||
BOOST_REQUIRE(!keys);
|
||||
BOOST_CHECK(std::holds_alternative<vector_store_client::disabled>(keys.error()));
|
||||
});
|
||||
}
|
||||
|
||||
SEASTAR_TEST_CASE(vector_store_client_test_ann_addr_unavailable) {
|
||||
auto cfg = cql_test_config();
|
||||
cfg.db_config->vector_store_uri.set("http://bad.authority.here:6080");
|
||||
co_await do_with_cql_env(
|
||||
[](cql_test_env& env) -> future<> {
|
||||
co_await env.execute_cql(R"(
|
||||
create table ks.vs (
|
||||
pk1 tinyint, pk2 tinyint,
|
||||
ck1 tinyint, ck2 tinyint,
|
||||
embedding vector<float, 3>,
|
||||
primary key ((pk1, pk2), ck1, ck2))
|
||||
)");
|
||||
|
||||
auto schema = env.local_db().find_schema("ks", "vs");
|
||||
auto& vs = env.local_qp().vector_store_client();
|
||||
|
||||
vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(1000));
|
||||
vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(100));
|
||||
vector_store_client_tester::set_http_request_retries(vs, 3);
|
||||
vector_store_client_tester::set_dns_resolver(vs, [](auto const& host) -> future<std::optional<inet_address>> {
|
||||
BOOST_CHECK_EQUAL(host, "bad.authority.here");
|
||||
co_return std::nullopt;
|
||||
});
|
||||
|
||||
vs.start_background_tasks();
|
||||
|
||||
auto as = abort_source();
|
||||
auto keys = co_await vs.ann("ks", "idx", schema, std::vector<float>{0.1, 0.2, 0.3}, 2, as);
|
||||
BOOST_REQUIRE(!keys);
|
||||
BOOST_CHECK(std::holds_alternative<vector_store_client::addr_unavailable>(keys.error()));
|
||||
},
|
||||
cfg);
|
||||
}
|
||||
|
||||
SEASTAR_TEST_CASE(vector_store_client_test_ann_service_unavailable) {
|
||||
auto cfg = cql_test_config();
|
||||
cfg.db_config->vector_store_uri.set(format("http://good.authority.here:{}", generate_unavailable_localhost_port()));
|
||||
co_await do_with_cql_env(
|
||||
[](cql_test_env& env) -> future<> {
|
||||
co_await env.execute_cql(R"(
|
||||
create table ks.vs (
|
||||
pk1 tinyint, pk2 tinyint,
|
||||
ck1 tinyint, ck2 tinyint,
|
||||
embedding vector<float, 3>,
|
||||
primary key ((pk1, pk2), ck1, ck2))
|
||||
)");
|
||||
|
||||
auto schema = env.local_db().find_schema("ks", "vs");
|
||||
auto& vs = env.local_qp().vector_store_client();
|
||||
|
||||
vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(1000));
|
||||
vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(100));
|
||||
vector_store_client_tester::set_http_request_retries(vs, 3);
|
||||
vector_store_client_tester::set_dns_resolver(vs, [](auto const& host) -> future<std::optional<inet_address>> {
|
||||
BOOST_CHECK_EQUAL(host, "good.authority.here");
|
||||
co_return inet_address("127.0.0.1");
|
||||
});
|
||||
|
||||
vs.start_background_tasks();
|
||||
|
||||
auto as = abort_source();
|
||||
auto keys = co_await vs.ann("ks", "idx", schema, std::vector<float>{0.1, 0.2, 0.3}, 2, as);
|
||||
BOOST_REQUIRE(!keys);
|
||||
BOOST_CHECK(std::holds_alternative<vector_store_client::service_unavailable>(keys.error()));
|
||||
},
|
||||
cfg);
|
||||
}
|
||||
|
||||
SEASTAR_TEST_CASE(vector_store_client_test_ann_service_aborted) {
|
||||
auto cfg = cql_test_config();
|
||||
cfg.db_config->vector_store_uri.set(format("http://good.authority.here:{}", generate_unavailable_localhost_port()));
|
||||
co_await do_with_cql_env(
|
||||
[](cql_test_env& env) -> future<> {
|
||||
co_await env.execute_cql(R"(
|
||||
create table ks.vs (
|
||||
pk1 tinyint, pk2 tinyint,
|
||||
ck1 tinyint, ck2 tinyint,
|
||||
embedding vector<float, 3>,
|
||||
primary key ((pk1, pk2), ck1, ck2))
|
||||
)");
|
||||
|
||||
auto schema = env.local_db().find_schema("ks", "vs");
|
||||
auto& vs = env.local_qp().vector_store_client();
|
||||
|
||||
vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(10));
|
||||
vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(100));
|
||||
vector_store_client_tester::set_http_request_retries(vs, 3);
|
||||
vector_store_client_tester::set_dns_resolver(vs, [](auto const& host) -> future<std::optional<inet_address>> {
|
||||
BOOST_CHECK_EQUAL(host, "good.authority.here");
|
||||
co_await sleep(std::chrono::milliseconds(100));
|
||||
co_return inet_address("127.0.0.1");
|
||||
});
|
||||
|
||||
vs.start_background_tasks();
|
||||
|
||||
auto as = abort_source();
|
||||
auto timeout = timer([&as]() {
|
||||
as.request_abort();
|
||||
});
|
||||
timeout.arm(std::chrono::milliseconds(10));
|
||||
auto keys = co_await vs.ann("ks", "idx", schema, std::vector<float>{0.1, 0.2, 0.3}, 2, as);
|
||||
BOOST_REQUIRE(!keys);
|
||||
BOOST_CHECK(std::holds_alternative<vector_store_client::aborted>(keys.error()));
|
||||
},
|
||||
cfg);
|
||||
}
|
||||
|
||||
|
||||
SEASTAR_TEST_CASE(vector_store_client_test_ann_request) {
|
||||
auto ann_replies = make_lw_shared<std::queue<std::tuple<sstring, sstring>>>();
|
||||
auto [server, addr] = co_await new_http_server([ann_replies](routes& r) {
|
||||
auto ann = [ann_replies](std::unique_ptr<request> req, std::unique_ptr<reply> rep) -> future<std::unique_ptr<reply>> {
|
||||
BOOST_REQUIRE(!ann_replies->empty());
|
||||
auto [req_exp, rep_inp] = ann_replies->front();
|
||||
auto const req_inp = co_await util::read_entire_stream_contiguous(*req->content_stream);
|
||||
BOOST_CHECK_EQUAL(req_inp, req_exp);
|
||||
ann_replies->pop();
|
||||
rep->set_status(status_type::ok);
|
||||
rep->write_body("json", rep_inp);
|
||||
co_return rep;
|
||||
};
|
||||
r.add(operation_type::POST, url("/api/v1/indexes/ks/idx").remainder("ann"), new function_handler(ann, "json"));
|
||||
});
|
||||
|
||||
auto cfg = cql_test_config();
|
||||
cfg.db_config->vector_store_uri.set(format("http://good.authority.here:{}", addr.port()));
|
||||
co_await do_with_cql_env(
|
||||
[&ann_replies](cql_test_env& env) -> future<> {
|
||||
co_await env.execute_cql(R"(
|
||||
create table ks.vs (
|
||||
pk1 tinyint, pk2 tinyint,
|
||||
ck1 tinyint, ck2 tinyint,
|
||||
embedding vector<float, 3>,
|
||||
primary key ((pk1, pk2), ck1, ck2))
|
||||
)");
|
||||
|
||||
auto schema = env.local_db().find_schema("ks", "vs");
|
||||
auto& vs = env.local_qp().vector_store_client();
|
||||
|
||||
vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(1000));
|
||||
vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(100));
|
||||
vector_store_client_tester::set_http_request_retries(vs, 3);
|
||||
vector_store_client_tester::set_dns_resolver(vs, [](auto const& host) -> future<std::optional<inet_address>> {
|
||||
BOOST_CHECK_EQUAL(host, "good.authority.here");
|
||||
co_return inet_address("127.0.0.1");
|
||||
});
|
||||
|
||||
vs.start_background_tasks();
|
||||
|
||||
// set the wrong idx (wrong endpoint) - service should return 404
|
||||
auto as = abort_source();
|
||||
auto keys = co_await vs.ann("ks", "idx2", schema, std::vector<float>{0.1, 0.2, 0.3}, 2, as);
|
||||
BOOST_REQUIRE(!keys);
|
||||
auto* err = std::get_if<vector_store_client::service_error>(&keys.error());
|
||||
BOOST_CHECK(err != nullptr);
|
||||
BOOST_CHECK_EQUAL(err->status, status_type::not_found);
|
||||
|
||||
// missing primary_keys in the reply - service should return format error
|
||||
ann_replies->emplace(std::make_tuple(R"({"embedding":[0.1,0.2,0.3],"limit":2})",
|
||||
R"({"primary_keys1":{"pk1":[5,6],"pk2":[7,8],"ck1":[9,1],"ck2":[2,3]},"distances":[0.1,0.2]})"));
|
||||
auto const now = lowres_clock::now();
|
||||
for (;;) {
|
||||
keys = co_await vs.ann("ks", "idx", schema, std::vector<float>{0.1, 0.2, 0.3}, 2, as);
|
||||
BOOST_REQUIRE(!keys);
|
||||
|
||||
// if the service is unavailable or 400, retry, seems http server is not ready yet
|
||||
auto* const unavailable = std::get_if<vector_store_client::service_unavailable>(&keys.error());
|
||||
auto* const service_error = std::get_if<vector_store_client::service_error>(&keys.error());
|
||||
if ((unavailable == nullptr && service_error == nullptr) ||
|
||||
(service_error != nullptr && service_error->status != status_type::bad_request)) {
|
||||
constexpr auto MAX_WAIT = std::chrono::seconds(5);
|
||||
BOOST_REQUIRE(lowres_clock::now() - now < MAX_WAIT);
|
||||
break;
|
||||
}
|
||||
}
|
||||
BOOST_CHECK(std::holds_alternative<vector_store_client::service_reply_format_error>(keys.error()));
|
||||
|
||||
// missing distances in the reply - service should return format error
|
||||
ann_replies->emplace(std::make_tuple(R"({"embedding":[0.1,0.2,0.3],"limit":2})",
|
||||
R"({"primary_keys":{"pk1":[5,6],"pk2":[7,8],"ck1":[9,1],"ck2":[2,3]},"distances1":[0.1,0.2]})"));
|
||||
keys = co_await vs.ann("ks", "idx", schema, std::vector<float>{0.1, 0.2, 0.3}, 2, as);
|
||||
BOOST_REQUIRE(!keys);
|
||||
BOOST_CHECK(std::holds_alternative<vector_store_client::service_reply_format_error>(keys.error()));
|
||||
|
||||
// missing pk1 key in the reply - service should return format error
|
||||
ann_replies->emplace(std::make_tuple(R"({"embedding":[0.1,0.2,0.3],"limit":2})",
|
||||
R"({"primary_keys":{"pk11":[5,6],"pk2":[7,8],"ck1":[9,1],"ck2":[2,3]},"distances":[0.1,0.2]})"));
|
||||
keys = co_await vs.ann("ks", "idx", schema, std::vector<float>{0.1, 0.2, 0.3}, 2, as);
|
||||
BOOST_REQUIRE(!keys);
|
||||
BOOST_CHECK(std::holds_alternative<vector_store_client::service_reply_format_error>(keys.error()));
|
||||
|
||||
// missing ck1 key in the reply - service should return format error
|
||||
ann_replies->emplace(std::make_tuple(R"({"embedding":[0.1,0.2,0.3],"limit":2})",
|
||||
R"({"primary_keys":{"pk1":[5,6],"pk2":[7,8],"ck11":[9,1],"ck2":[2,3]},"distances":[0.1,0.2]})"));
|
||||
keys = co_await vs.ann("ks", "idx", schema, std::vector<float>{0.1, 0.2, 0.3}, 2, as);
|
||||
BOOST_REQUIRE(!keys);
|
||||
BOOST_CHECK(std::holds_alternative<vector_store_client::service_reply_format_error>(keys.error()));
|
||||
|
||||
// wrong size of pk2 key in the reply - service should return format error
|
||||
ann_replies->emplace(std::make_tuple(R"({"embedding":[0.1,0.2,0.3],"limit":2})",
|
||||
R"({"primary_keys":{"pk1":[5,6],"pk2":[78],"ck1":[9,1],"ck2":[2,3]},"distances":[0.1,0.2]})"));
|
||||
keys = co_await vs.ann("ks", "idx", schema, std::vector<float>{0.1, 0.2, 0.3}, 2, as);
|
||||
BOOST_REQUIRE(!keys);
|
||||
BOOST_CHECK(std::holds_alternative<vector_store_client::service_reply_format_error>(keys.error()));
|
||||
|
||||
// wrong size of ck2 key in the reply - service should return format error
|
||||
ann_replies->emplace(std::make_tuple(R"({"embedding":[0.1,0.2,0.3],"limit":2})",
|
||||
R"({"primary_keys":{"pk1":[5,6],"pk2":[7,8],"ck1":[9,1],"ck2":[23]},"distances":[0.1,0.2]})"));
|
||||
keys = co_await vs.ann("ks", "idx", schema, std::vector<float>{0.1, 0.2, 0.3}, 2, as);
|
||||
BOOST_REQUIRE(!keys);
|
||||
BOOST_CHECK(std::holds_alternative<vector_store_client::service_reply_format_error>(keys.error()));
|
||||
|
||||
// correct reply - service should return keys
|
||||
ann_replies->emplace(std::make_tuple(R"({"embedding":[0.1,0.2,0.3],"limit":2})",
|
||||
R"({"primary_keys":{"pk1":[5,6],"pk2":[7,8],"ck1":[9,1],"ck2":[2,3]},"distances":[0.1,0.2]})"));
|
||||
keys = co_await vs.ann("ks", "idx", schema, std::vector<float>{0.1, 0.2, 0.3}, 2, as);
|
||||
BOOST_REQUIRE(keys);
|
||||
BOOST_REQUIRE_EQUAL(keys->size(), 2);
|
||||
BOOST_CHECK_EQUAL(seastar::format("{}", keys->at(0).partition.key().explode()), "[05, 07]");
|
||||
BOOST_CHECK_EQUAL(seastar::format("{}", keys->at(0).clustering.explode()), "[09, 02]");
|
||||
BOOST_CHECK_EQUAL(seastar::format("{}", keys->at(1).partition.key().explode()), "[06, 08]");
|
||||
BOOST_CHECK_EQUAL(seastar::format("{}", keys->at(1).clustering.explode()), "[01, 03]");
|
||||
},
|
||||
cfg);
|
||||
co_await server->stop();
|
||||
}
|
||||
|
||||
@@ -170,6 +170,7 @@ private:
|
||||
sharded<gms::gossip_address_map> _gossip_address_map;
|
||||
sharded<service::direct_fd_pinger> _fd_pinger;
|
||||
sharded<cdc::cdc_service> _cdc;
|
||||
sharded<service::vector_store_client> _vector_store_client;
|
||||
db::config* _db_config;
|
||||
|
||||
service::raft_group0_client* _group0_client;
|
||||
@@ -704,7 +705,12 @@ private:
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(cql3::prepared_statements_cache::entry_expiry));
|
||||
auth_prep_cache_config.refresh = std::chrono::milliseconds(cfg->permissions_update_interval_in_ms());
|
||||
|
||||
_qp.start(std::ref(_proxy), std::move(local_data_dict), std::ref(_mnotifier), qp_mcfg, std::ref(_cql_config), auth_prep_cache_config, std::ref(_lang_manager)).get();
|
||||
_vector_store_client.start(std::ref(*cfg)).get();
|
||||
auto stop_vector_store_client = defer_verbose_shutdown("vector store client", [this] {
|
||||
_vector_store_client.stop().get();
|
||||
});
|
||||
|
||||
_qp.start(std::ref(_proxy), std::move(local_data_dict), std::ref(_mnotifier), std::ref(_vector_store_client), qp_mcfg, std::ref(_cql_config), auth_prep_cache_config, std::ref(_lang_manager)).get();
|
||||
auto stop_qp = defer_verbose_shutdown("query processor", [this] { _qp.stop().get(); });
|
||||
|
||||
_elc_notif.start().get();
|
||||
|
||||
@@ -21,6 +21,7 @@ template<typename T>
|
||||
class sequential_producer {
|
||||
public:
|
||||
using factory_t = std::function<seastar::future<T>()>;
|
||||
using time_point = seastar::shared_future<T>::time_point;
|
||||
|
||||
private:
|
||||
factory_t _factory;
|
||||
@@ -32,11 +33,18 @@ class sequential_producer {
|
||||
clear();
|
||||
}
|
||||
|
||||
seastar::future<T> operator()() {
|
||||
seastar::future<T> operator()(time_point timeout = time_point::max()) {
|
||||
if (_churning.available()) {
|
||||
_churning = _factory();
|
||||
}
|
||||
return _churning.get_future();
|
||||
return _churning.get_future(timeout);
|
||||
}
|
||||
|
||||
seastar::future<T> operator()(seastar::abort_source& as) {
|
||||
if (_churning.available()) {
|
||||
_churning = _factory();
|
||||
}
|
||||
return _churning.get_future(as);
|
||||
}
|
||||
|
||||
void clear() {
|
||||
|
||||
Reference in New Issue
Block a user