scylladb/vector_search/vector_store_client.cc
Nadav Har'El c56361a6d7 vector_store_client: read and return similarity_scores
For every ANN search, the vector store returns, in addition to the keys
of the matching items, two additional vectors - "distances" and
"similarity_scores". The "distances" are raw distance metrics, where lower
scores are better matches, while "similarity_scores" are transformed
such that higher scores are better matches.

Traditionally, search scores in systems like Cassandra and OpenSearch
use the "similarity scores" approach (higher is better, and results are
returned in decreasing similarity order), so this is the more interesting
of the two vectors.
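As a minimal illustration of the two conventions (not code from this patch; `scored_item`, `rank_by_distance()` and `rank_by_similarity()` are hypothetical), ranking by raw distance sorts ascending, while ranking by similarity score sorts descending. In both cases the best match comes first:

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for one search hit; not a ScyllaDB type.
struct scored_item {
    std::string key;
    float score;
};

// Distance convention: lower is better, so sort ascending.
std::vector<scored_item> rank_by_distance(std::vector<scored_item> items) {
    std::stable_sort(items.begin(), items.end(),
            [](const scored_item& a, const scored_item& b) { return a.score < b.score; });
    return items;
}

// Similarity convention (as in Cassandra and OpenSearch): higher is better,
// so results are returned in decreasing similarity order.
std::vector<scored_item> rank_by_similarity(std::vector<scored_item> items) {
    std::stable_sort(items.begin(), items.end(),
            [](const scored_item& a, const scored_item& b) { return a.score > b.score; });
    return items;
}
```

`std::stable_sort` keeps items with equal scores in their original relative order.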

But before this patch, our vector_store_client::ann() inspected
only "distances", and then didn't even return those to the
caller :-)

So in this patch, we:

1. Ignore "distances" and instead look at "similarity scores",
   which is what users really want based on their experience with
   other vector and non-vector search engines.

2. Return the similarity score of each match together with the match.
   We already have this score (the vector store returns it) and we
   can add it to the existing primary_key structure of each result.
   So each result is a "struct primary_key", which has the fields partition
   and clustering and, after this patch, also similarity.

Existing callers in CQL and Alternator vector search will ignore this
"similarity" field in each result, and not notice it was added.
But in the next patch, we'll allow Alternator's vector search to
return this similarity in each result.
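Here is a simplified sketch of the result-type change described above. The real `vector_search::primary_key` holds a decorated partition key and a clustering key prefix. In this sketch, `std::string` stands in for both, and `primary_key_sketch` and `key_of()` are hypothetical names. An existing caller that reads only the key fields keeps working unchanged after the new field is added:

```cpp
#include <cassert>
#include <string>

// Simplified model of one ANN result (hypothetical; strings stand in for
// the real partition and clustering key types).
struct primary_key_sketch {
    std::string partition;
    std::string clustering;
    float similarity; // added by this patch: higher is a better match
};

// Models an existing caller that only looks at the key fields and
// therefore doesn't notice the new "similarity" member.
std::string key_of(const primary_key_sketch& pk) {
    return pk.partition + "/" + pk.clustering;
}
```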

The existing unit tests for vector_store_client.cc mocked vector-store
responses with "distances" and without "similarity_scores", so they no
longer represent what we actually expect the vector store to return. So
this patch also modifies these tests to mock and check
"similarity_scores" instead of "distances". The more interesting tests, in
the next patch, use the real vector store and check that we really do
get a "similarity_scores" response from it.
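For reference, `read_ann_json()` in this file expects a columnar reply. "primary_keys" is an object that maps each key column name to an array, and "similarity_scores" is a parallel array, so entry i of every array describes match i.

The sketch below models a mocked reply with standard containers instead of rjson. The column names "pk" and "ck", the `columns` and `row` types, and `transpose()` are all hypothetical. Like `get_key_column_value()`, it rejects a key column that is shorter than the score array:

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Mocked reply (hypothetical column names "pk" and "ck"):
//   {"primary_keys": {"pk": ["p1", "p2"], "ck": ["c1", "c2"]},
//    "similarity_scores": [0.92, 0.41]}
// Key values are modeled as strings; scores as floats.
using columns = std::map<std::string, std::vector<std::string>>;

struct row {
    std::map<std::string, std::string> key; // column name -> value
    float similarity;
};

// Transposes the columnar layout into one row per match. The score array
// determines the number of matches; a key column that is too short makes
// the whole reply invalid (std::nullopt).
std::optional<std::vector<row>> transpose(const columns& primary_keys, const std::vector<float>& similarity_scores) {
    std::vector<row> rows;
    for (std::size_t idx = 0; idx < similarity_scores.size(); ++idx) {
        row r{{}, similarity_scores[idx]};
        for (const auto& [name, values] : primary_keys) {
            if (values.size() <= idx) {
                return std::nullopt;
            }
            r.key[name] = values[idx];
        }
        rows.push_back(std::move(r));
    }
    return rows;
}
```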

This patch also handles a small corner case for DOT_PRODUCT, which is
the only unbounded similarity function. If the similarity overflows
the 32-bit float, the vector store returns a JSON "null" instead of
a JSON number (since JSON doesn't support infinite numbers). Our
existing vector-store client code errored out when it saw this "null",
which is wrong - the request should be allowed to proceed. So in this
patch when we see a "null" JSON for similarity, we return +Inf.
This is usually correct, because the top results really have +Inf, not
-Inf. But if we ask for all items, we can reach items whose similarity
is -Inf and incorrectly assign +Inf to them (we have a test for this
case in the next patch). This problem won't happen when Limit is low,
and in any case it's better than aborting the request after it had
already succeeded.
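The overflow case can be reproduced with plain `float` arithmetic. In this sketch, `dot_product()` and `similarity_from_json()` are hypothetical helpers, not the vector store's code:

- A dot product accumulated in 32-bit floats overflows to +inf or -inf.
- `std::optional<float>` models a score that arrived as JSON null.
- Mapping that null to +inf mirrors the choice described above.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <optional>
#include <vector>

// Dot product accumulated in 32-bit float. With IEEE 754 default rounding,
// large-magnitude inputs overflow to +inf or -inf.
float dot_product(const std::vector<float>& a, const std::vector<float>& b) {
    float sum = 0.0f;
    for (std::size_t i = 0; i < a.size() && i < b.size(); ++i) {
        sum += a[i] * b[i];
    }
    return sum;
}

// Models a similarity score read from JSON: std::nullopt stands for a JSON
// null. Because null can't tell +inf from -inf, it is mapped to +inf.
float similarity_from_json(std::optional<float> json_value) {
    return json_value.value_or(std::numeric_limits<float>::infinity());
}
```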

Signed-off-by: Nadav Har'El <nyh@scylladb.com>
2026-05-13 14:19:17 +03:00


/*
 * Copyright (C) 2025-present ScyllaDB
 */

/*
 * SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.1
 */

#include "vector_store_client.hh"
#include "dns.hh"
#include "clients.hh"
#include "uri.hh"
#include "utils.hh"
#include "truststore.hh"
#include "db/config.hh"
#include "exceptions/exceptions.hh"
#include "dht/i_partitioner.hh"
#include "keys/keys.hh"
#include "utils/rjson.hh"
#include "types/json_utils.hh"
#include "schema/schema.hh"
#include <charconv>
#include <exception>
#include <fmt/ranges.h>
#include <limits>
#include <regex>
#include <seastar/core/sstring.hh>
#include <seastar/core/metrics.hh>
#include <seastar/coroutine/as_future.hh>
#include <seastar/coroutine/exception.hh>
#include <seastar/http/client.hh>
#include <seastar/http/request.hh>
#include <seastar/net/dns.hh>
#include <seastar/net/inet_address.hh>
#include <seastar/net/socket_defs.hh>
#include <seastar/util/lazy.hh>
#include <seastar/util/short_streams.hh>

namespace {

using namespace std::chrono_literals;
using ann_error = vector_search::vector_store_client::ann_error;
using configuration_exception = exceptions::configuration_exception;
using duration = lowres_clock::duration;
using vs_vector = vector_search::vector_store_client::vs_vector;
using limit = vector_search::vector_store_client::limit;
using host_name = vector_search::vector_store_client::host_name;
using http_path = sstring;
using inet_address = seastar::net::inet_address;
using json_content = sstring;
using milliseconds = std::chrono::milliseconds;
using operation_type = httpd::operation_type;
using port_number = vector_search::vector_store_client::port_number;
using primary_key = vector_search::primary_key;
using primary_keys = vector_search::vector_store_client::primary_keys;
using service_reply_format_error = vector_search::vector_store_client::service_reply_format_error;
using uri = vector_search::uri;

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
logging::logger vslogger("vector_store_client");

auto parse_port(std::string const& port_txt) -> std::optional<port_number> {
    auto port = port_number{};
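    // Parse the whole string as a valid [begin, end) range. Reject parse
    // errors (including overflow) and any characters after the number.
    auto const* const begin = port_txt.data();
    auto const* const end = begin + port_txt.size();
    auto [ptr, ec] = std::from_chars(begin, end, port);
    if (ec != std::errc{} || ptr != end) {
        return std::nullopt;
    }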
    return port;
}

auto parse_service_uri(std::string_view uri_) -> std::optional<uri> {
    constexpr auto URI_REGEX = R"(^(http|https):\/\/([a-z0-9._-]+):([0-9]+)$)";
    auto const uri_regex = std::regex(URI_REGEX);
    auto uri_match = std::smatch{};
    auto uri_txt = std::string(uri_);
    if (!std::regex_match(uri_txt, uri_match, uri_regex) || uri_match.size() != 4) {
        return {};
    }
    auto schema = uri_match[1].str() == "https" ? uri::schema_type::https : uri::schema_type::http;
    auto host = uri_match[2].str();
    auto port = parse_port(uri_match[3].str());
    if (!port) {
        return {};
    }
    return {{schema, host, *port}};
}

auto get_key_column_value(const rjson::value& item, std::size_t idx, const column_definition& column) -> std::expected<bytes, ann_error> {
    auto const& column_name = column.name_as_text();
    auto const* keys_obj = rjson::find(item, column_name);
    if (keys_obj == nullptr) {
        vslogger.error("Vector Store returned invalid JSON: missing key column '{}'", column_name);
        return std::unexpected{service_reply_format_error{}};
    }
    if (!keys_obj->IsArray()) {
        vslogger.error("Vector Store returned invalid JSON: key column '{}' is not an array", column_name);
        return std::unexpected{service_reply_format_error{}};
    }
    auto const& keys_arr = keys_obj->GetArray();
    if (keys_arr.Size() <= idx) {
        vslogger.error("Vector Store returned invalid JSON: key column '{}' array too small", column_name);
        return std::unexpected{service_reply_format_error{}};
    }
    auto const& key = keys_arr[idx];
    return from_json_object(*column.type, key);
}

auto pk_from_json(rjson::value const& item, std::size_t idx, schema_ptr const& schema) -> std::expected<partition_key, ann_error> {
    std::vector<bytes> raw_pk;
    for (const column_definition& cdef : schema->partition_key_columns()) {
        auto raw_value = get_key_column_value(item, idx, cdef);
        if (!raw_value) {
            return std::unexpected{raw_value.error()};
        }
        raw_pk.emplace_back(*raw_value);
    }
    return partition_key::from_exploded(raw_pk);
}

auto ck_from_json(rjson::value const& item, std::size_t idx, schema_ptr const& schema) -> std::expected<clustering_key_prefix, ann_error> {
    if (schema->clustering_key_size() == 0) {
        return clustering_key_prefix::make_empty();
    }
    std::vector<bytes> raw_ck;
    for (const column_definition& cdef : schema->clustering_key_columns()) {
        auto raw_value = get_key_column_value(item, idx, cdef);
        if (!raw_value) {
            return std::unexpected{raw_value.error()};
        }
        raw_ck.emplace_back(*raw_value);
    }
    return clustering_key_prefix::from_exploded(raw_ck);
}

auto write_ann_json(vs_vector vs_vector, limit limit, const rjson::value& filter) -> json_content {
    if (filter.ObjectEmpty()) {
        return seastar::format(R"({{"vector":[{}],"limit":{}}})", fmt::join(vs_vector, ","), limit);
    }
    return seastar::format(R"({{"vector":[{}],"limit":{},"filter":{}}})", fmt::join(vs_vector, ","), limit, rjson::print(filter));
}

auto read_ann_json(rjson::value const& json, schema_ptr const& schema) -> std::expected<primary_keys, ann_error> {
    if (!json.HasMember("primary_keys")) {
        vslogger.error("Vector Store returned invalid JSON: missing 'primary_keys'");
        return std::unexpected{service_reply_format_error{}};
    }
    auto const& keys_json = json["primary_keys"];
    if (!keys_json.IsObject()) {
        vslogger.error("Vector Store returned invalid JSON: 'primary_keys' is not an object");
        return std::unexpected{service_reply_format_error{}};
    }
    if (!json.HasMember("similarity_scores")) {
        vslogger.error("Vector Store returned invalid JSON: missing 'similarity_scores'");
        return std::unexpected{service_reply_format_error{}};
    }
    auto const& similarity_json = json["similarity_scores"];
    if (!similarity_json.IsArray()) {
        vslogger.error("Vector Store returned invalid JSON: 'similarity_scores' is not an array");
        return std::unexpected{service_reply_format_error{}};
    }
    auto const& similarity_arr = similarity_json.GetArray();
    // The number of nearest neighbors returned by the vector store is taken
    // from similarity_arr. Every key array in keys_json should have the same
    // length: get_key_column_value() rejects a key array that is too short,
    // while extra trailing entries are ignored.
    auto size = similarity_arr.Size();
    auto keys = primary_keys{};
    for (auto idx = 0U; idx < size; ++idx) {
        auto pk = pk_from_json(keys_json, idx, schema);
        if (!pk) {
            return std::unexpected{pk.error()};
        }
        auto ck = ck_from_json(keys_json, idx, schema);
        if (!ck) {
            return std::unexpected{ck.error()};
        }
        auto const& sim_val = similarity_arr[idx];
        if (sim_val.IsNumber()) {
            keys.push_back(primary_key{dht::decorate_key(*schema, *pk), *ck, sim_val.GetFloat()});
        } else if (sim_val.IsNull()) {
            // JSON does not support infinite values, and serde_json serializes
            // both +inf and -inf as null. This can only happen with the
            // DOT_PRODUCT similarity function when the dot product overflows
            // float32 (very high magnitude vectors). Since we can't distinguish
            // +inf from -inf here, we use +inf as a conservative approximation
            // (the item will be ranked highly, which is appropriate for a very
            // large positive dot product; the -inf case is very unlikely in
            // practice).
            keys.push_back(primary_key{dht::decorate_key(*schema, *pk), *ck, std::numeric_limits<float>::infinity()});
        } else {
            vslogger.error("Vector Store returned invalid JSON: 'similarity_scores[{}]'={} is not a number", idx, rjson::print(sim_val));
            return std::unexpected{service_reply_format_error{}};
        }
    }
    return std::move(keys);
}

bool should_vector_store_service_be_disabled(std::vector<sstring> const& uris) {
    return uris.empty() || uris[0].empty();
}

auto parse_uris(std::string_view uris_csv) -> std::vector<uri> {
    std::vector<uri> ret;
    auto uris = utils::split_comma_separated_list(uris_csv);
    if (should_vector_store_service_be_disabled(uris)) {
        vslogger.info("Vector Store service URIs are empty, disabling Vector Store service");
        return ret;
    }
    for (const auto& uri : uris) {
        auto parsed = parse_service_uri(uri);
        if (!parsed) {
            throw configuration_exception(fmt::format("Invalid Vector Store service URI: {}", uri));
        }
        ret.push_back(*parsed);
    }
    vslogger.info("Vector Store service URIs set to: '{}'", uris_csv);
    return ret;
}

auto parse_uris_no_throw(std::string_view uris_csv) -> std::vector<uri> {
    try {
        return parse_uris(uris_csv);
    } catch (const configuration_exception& e) {
        vslogger.error("Failed to parse Vector Store service URIs [{}]: {}", uris_csv, e.what());
    }
    return {};
}

std::vector<sstring> get_hosts(const std::vector<uri>& primary_uris, const std::vector<uri>& secondary_uris) {
    std::vector<sstring> ret;
    for (const auto& uri : primary_uris) {
        ret.push_back(uri.host);
    }
    for (const auto& uri : secondary_uris) {
        ret.push_back(uri.host);
    }
    return ret;
}

} // namespace

namespace vector_search {

struct vector_store_client::impl {
    using invoke_on_others_func = std::function<future<>(std::function<future<>(impl&)>)>;

    utils::observer<sstring> _primary_uri_observer;
    utils::observer<sstring> _secondary_uri_observer;
    std::vector<uri> _primary_uris;
    std::vector<uri> _secondary_uris;
    dns _dns;
    uint64_t _dns_refreshes = 0;
    seastar::metrics::metric_groups _metrics;
    truststore _truststore;
    clients _primary_clients;
    clients _secondary_clients;

    impl(utils::config_file::named_value<sstring> primary_uris, utils::config_file::named_value<sstring> secondary_uris,
            utils::config_file::named_value<uint32_t> unreachable_node_detection_time_in_ms,
            utils::config_file::named_value<utils::config_file::string_map> encryption_options, invoke_on_others_func invoke_on_others)
        : _primary_uri_observer(primary_uris.observe([this](seastar::sstring uris_csv) {
            handle_uris_changed(std::move(uris_csv), _primary_uris, _primary_clients);
        }))
        , _secondary_uri_observer(secondary_uris.observe([this](seastar::sstring uris_csv) {
            handle_uris_changed(std::move(uris_csv), _secondary_uris, _secondary_clients);
        }))
        , _primary_uris(parse_uris(primary_uris()))
        , _secondary_uris(parse_uris(secondary_uris()))
        , _dns(
                vslogger, get_hosts(_primary_uris, _secondary_uris),
                [this](auto const& addrs) -> future<> {
                    co_await handle_addresses_changed(addrs);
                },
                _dns_refreshes)
        , _truststore(vslogger, encryption_options,
                [invoke_on_others = std::move(invoke_on_others)](auto func) {
                    return invoke_on_others([func = std::move(func)](auto& self) {
                        return func(self._truststore);
                    });
                })
        , _primary_clients(
                vslogger,
                [this]() {
                    _dns.trigger_refresh();
                },
                unreachable_node_detection_time_in_ms, _truststore)
        , _secondary_clients(
                vslogger,
                [this]() {
                    _dns.trigger_refresh();
                },
                unreachable_node_detection_time_in_ms, _truststore) {
        _metrics.add_group("vector_store", {seastar::metrics::make_gauge("dns_refreshes", seastar::metrics::description("Number of DNS refreshes"), [this] {
            return _dns_refreshes;
        }).aggregate({seastar::metrics::shard_label})});
    }

    void handle_uris_changed(seastar::sstring uris_csv, std::vector<uri>& uris, clients& clients) {
        clients.clear();
        uris = parse_uris_no_throw(uris_csv);
        _dns.hosts(get_hosts(_primary_uris, _secondary_uris));
    }

    future<> handle_addresses_changed(const dns::host_address_map& addrs) {
        co_await _primary_clients.handle_changed(_primary_uris, addrs);
        co_await _secondary_clients.handle_changed(_secondary_uris, addrs);
    }

    auto is_disabled() const -> bool {
        return _primary_uris.empty() && _secondary_uris.empty();
    }

    auto get_index_status(keyspace_name keyspace, index_name name, abort_source& as)
            -> future<vector_store_client::index_status> {
        using index_status = vector_store_client::index_status;
        if (is_disabled()) {
            co_return index_status::creating;
        }
        auto path = format("/api/v1/indexes/{}/{}/status", keyspace, name);
        auto resp = co_await request(operation_type::GET, std::move(path), std::nullopt, as);
        if (!resp || resp->status != status_type::ok) {
            co_return index_status::creating;
        }
        try {
            auto json = rjson::parse(response_content_to_sstring(resp->content));
            const auto* status = rjson::find(json, "status");
            if (!status || !status->IsString()) {
                co_return index_status::creating;
            }
            auto sv = rjson::to_string_view(*status);
            if (sv == "SERVING") {
                co_return index_status::serving;
            }
            if (sv == "BOOTSTRAPPING") {
                co_return index_status::backfilling;
            }
            co_return index_status::creating;
        } catch (...) {
            co_return index_status::creating;
        }
    }

    auto ann(keyspace_name keyspace, index_name name, schema_ptr schema, vs_vector vs_vector, limit limit, const rjson::value& filter, abort_source& as)
            -> future<std::expected<primary_keys, ann_error>> {
        if (is_disabled()) {
            vslogger.error("Disabled Vector Store while calling ann");
            co_return std::unexpected{disabled{}};
        }
        auto path = format("/api/v1/indexes/{}/{}/ann", keyspace, name);
        auto content = write_ann_json(std::move(vs_vector), limit, filter);
        auto resp = co_await request(operation_type::POST, std::move(path), std::move(content), as);
        if (!resp) {
            co_return std::unexpected{std::visit(
                    [](auto&& err) {
                        return ann_error{err};
                    },
                    resp.error())};
        }
        if (resp->status != status_type::ok) {
            auto error_content = response_content_to_sstring(resp->content);
            vslogger.error("Vector Store returned error: HTTP status {}: {}", resp->status, error_content);
            co_return std::unexpected{service_error{resp->status, std::move(error_content)}};
        }
        try {
            co_return read_ann_json(rjson::parse(std::move(resp->content)), schema);
        } catch (const rjson::error& e) {
            vslogger.error("Vector Store returned invalid JSON: {}", e.what());
            co_return std::unexpected{service_reply_format_error{}};
        }
    }

    future<clients::request_result> request(
            seastar::httpd::operation_type method, seastar::sstring path, std::optional<seastar::sstring> content, seastar::abort_source& as) {
        auto success_or_aborted = [](const auto& result) {
            return result || std::holds_alternative<vector_store_client::aborted>(result.error());
        };
        if (!_primary_uris.empty()) {
            auto result = co_await _primary_clients.request(method, path, content, as);
            if (success_or_aborted(result) || _secondary_uris.empty()) {
                co_return result;
            }
        }
        if (!_secondary_uris.empty()) {
            co_return co_await _secondary_clients.request(method, path, content, as);
        }
        co_return std::unexpected{service_unavailable{}};
    }
};

vector_store_client::vector_store_client(config const& cfg)
    : _impl(std::make_unique<impl>(cfg.vector_store_primary_uri, cfg.vector_store_secondary_uri, cfg.vector_store_unreachable_node_detection_time_in_ms,
            cfg.vector_store_encryption_options, [this](auto func) {
                return container().invoke_on_others([func = std::move(func)](auto& self) {
                    return func(*self._impl);
                });
            })) {
}

vector_store_client::~vector_store_client() = default;

void vector_store_client::start_background_tasks() {
    _impl->_dns.start_background_tasks();
}

auto vector_store_client::stop() -> future<> {
    co_await _impl->_primary_clients.stop();
    co_await _impl->_secondary_clients.stop();
    co_await _impl->_dns.stop();
    co_await _impl->_truststore.stop();
}

auto vector_store_client::is_disabled() const -> bool {
    return _impl->is_disabled();
}

auto vector_store_client::get_index_status(keyspace_name keyspace, index_name name, abort_source& as) -> future<index_status> {
    return _impl->get_index_status(std::move(keyspace), std::move(name), as);
}

auto vector_store_client::ann(keyspace_name keyspace, index_name name, schema_ptr schema, vs_vector vs_vector, limit limit, const rjson::value& filter,
        abort_source& as) -> future<std::expected<primary_keys, ann_error>> {
    return _impl->ann(std::move(keyspace), std::move(name), std::move(schema), std::move(vs_vector), limit, filter, as);
}

void vector_store_client_tester::set_dns_refresh_interval(vector_store_client& vsc, std::chrono::milliseconds interval) {
    vsc._impl->_dns.refresh_interval(interval);
}

void vector_store_client_tester::set_wait_for_client_timeout(vector_store_client& vsc, std::chrono::milliseconds timeout) {
    vsc._impl->_primary_clients.timeout(timeout);
    vsc._impl->_secondary_clients.timeout(timeout);
}

void vector_store_client_tester::set_dns_resolver(vector_store_client& vsc, std::function<future<std::vector<inet_address>>(sstring const&)> resolver) {
    vsc._impl->_dns.resolver(std::move(resolver));
}

void vector_store_client_tester::trigger_dns_resolver(vector_store_client& vsc) {
    vsc._impl->_dns.trigger_refresh();
}

auto vector_store_client_tester::resolve_hostname(vector_store_client& vsc, abort_source& as) -> future<std::vector<inet_address>> {
    auto clients = co_await vsc._impl->_primary_clients.get_clients(as);
    std::vector<inet_address> ret;
    if (!clients) {
        co_return ret;
    }
    for (auto const& c : *clients) {
        ret.push_back(c->endpoint().ip);
    }
    co_return ret;
}

unsigned vector_store_client_tester::truststore_reload_count(vector_store_client& vsc) {
    return vsc._impl->_truststore.reload_count();
}

} // namespace vector_search