diff --git a/configure.py b/configure.py index 2e4e3e5934..69e89bffcb 100755 --- a/configure.py +++ b/configure.py @@ -638,7 +638,8 @@ raft_tests = set([ ]) vector_search_tests = set([ - 'test/vector_search/vector_store_client_test' + 'test/vector_search/vector_store_client_test', + 'test/vector_search/load_balancer_test' ]) wasms = set([ @@ -1652,6 +1653,7 @@ deps['test/raft/discovery_test'] = ['test/raft/discovery_test.cc', 'service/raft/discovery.cc'] + scylla_raft_dependencies deps['test/vector_search/vector_store_client_test'] = ['test/vector_search/vector_store_client_test.cc'] + scylla_tests_dependencies +deps['test/vector_search/load_balancer_test'] = ['test/vector_search/load_balancer_test.cc'] + scylla_tests_dependencies wasm_deps = {} diff --git a/test/vector_search/CMakeLists.txt b/test/vector_search/CMakeLists.txt index 30d4bffb5a..2336103812 100644 --- a/test/vector_search/CMakeLists.txt +++ b/test/vector_search/CMakeLists.txt @@ -1,2 +1,5 @@ add_scylla_test(vector_store_client_test - KIND SEASTAR LIBRARIES vector_search) + LIBRARIES vector_search) + +add_scylla_test(load_balancer_test + LIBRARIES vector_search) diff --git a/test/vector_search/load_balancer_test.cc b/test/vector_search/load_balancer_test.cc new file mode 100644 index 0000000000..7c4b5794d2 --- /dev/null +++ b/test/vector_search/load_balancer_test.cc @@ -0,0 +1,36 @@ +/* + * Copyright (C) 2025-present ScyllaDB + */ + +/* + * SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 + */ + +#include "vector_search/load_balancer.hh" +#include +#include + +using namespace seastar; +using namespace vector_search; + +BOOST_AUTO_TEST_CASE(next_returns_nullptr_on_empty_container) { + std::mt19937 seeded_engine(0); + load_balancer lb{std::vector>{}, seeded_engine}; + + BOOST_CHECK(lb.next() == nullptr); +} + +BOOST_AUTO_TEST_CASE(next_returns_all_elements_in_random_order) { + /* fixed seed so that the exact order asserted at the end of this test is reproducible; presumably the order also depends on the standard library implementation of uniform_int_distribution - TODO confirm across toolchains */ std::mt19937 seeded_engine(0); + std::vector> read; + load_balancer lb{std::vector>{make_lw_shared(1), 
make_lw_shared(2), make_lw_shared(3)}, seeded_engine}; + + while (auto n = lb.next()) { + read.push_back(n); + } + + BOOST_CHECK_EQUAL(read.size(), 3); + BOOST_CHECK_EQUAL(*read[0], 2); + BOOST_CHECK_EQUAL(*read[1], 3); + BOOST_CHECK_EQUAL(*read[2], 1); +} diff --git a/test/vector_search/vector_store_client_test.cc b/test/vector_search/vector_store_client_test.cc index abf6770638..81aba977ae 100644 --- a/test/vector_search/vector_store_client_test.cc +++ b/test/vector_search/vector_store_client_test.cc @@ -133,8 +133,8 @@ auto create_test_table(cql_test_env& env, const sstring& ks, const sstring& cf) } future<> try_on_loopback_address(std::function(sstring)> func) { - constexpr size_t MAX_ADDR = 20; - for (size_t i = 1; i < MAX_ADDR; i++) { + constexpr size_t MAX_LOCALHOST_ADDR_TO_TRY = 127; + for (size_t i = 1; i < MAX_LOCALHOST_ADDR_TO_TRY; i++) { auto host = fmt::format("127.0.0.{}", i); try { co_await func(std::move(host)); @@ -833,12 +833,17 @@ SEASTAR_TEST_CASE(vector_store_client_high_availability_host_resolved_to_multipl auto& vs = env.local_qp().vector_store_client(); configure(vs).with_dns({{"good.authority.here", std::vector{unavail_s->host(), LOCALHOST}}}); vs.start_background_tasks(); + std::expected keys; - auto keys = co_await vs.ann("ks", "idx", schema, std::vector{0.1, 0.2, 0.3}, 2, as.as); + // Because requests are distributed in random order due to load balancing, + // repeat the ANN query until the unavailable server is queried. 
+ BOOST_CHECK(co_await repeat_until(std::chrono::seconds(10), [&]() -> future { + keys = co_await vs.ann("ks", "idx", schema, std::vector{0.1, 0.2, 0.3}, 2, as.as); + /* NOTE(review): this condition needs more than one connection to the unavailable server, while the comment before the loop only asks for it to be queried - confirm that > 1 rather than > 0 is intended */ co_return unavail_s->connections() > 1; + })); - // tried to connect to the unavailable server as it is first in the list of resolved addresses - BOOST_CHECK_EQUAL(unavail_s->connections(), 1); - // successfully got keys from the responding server + // The query is successful because the client falls back to the available server + // when the attempt to connect to the unavailable one fails. BOOST_CHECK(keys); }, cfg) @@ -847,3 +852,33 @@ SEASTAR_TEST_CASE(vector_store_client_high_availability_host_resolved_to_multipl co_await unavail_s->stop(); })); } + +SEASTAR_TEST_CASE(vector_store_client_load_balancing) { + + auto s1 = co_await make_vs_mock_server(); + auto s2 = co_await make_vs_mock_server(s1->port()); + + auto cfg = cql_test_config(); + cfg.db_config->vector_store_primary_uri.set(format("http://good.authority.here:{}", s1->port())); + co_await do_with_cql_env( + [&](cql_test_env& env) -> future<> { + auto as = abort_source_timeout(); + auto schema = co_await create_test_table(env, "ks", "idx"); + auto& vs = env.local_qp().vector_store_client(); + configure(vs).with_dns({{"good.authority.here", std::vector{s1->host(), s2->host()}}}); + vs.start_background_tasks(); + + // Wait until requests are handled by both servers. + // The load balancing algorithm is random, so we send requests in a loop + // until both servers have received at least one, verifying that load is distributed. 
+ BOOST_CHECK(co_await repeat_until(std::chrono::seconds(10), [&]() -> future { + co_await vs.ann("ks", "idx", schema, std::vector{0.1, 0.2, 0.3}, 2, as.as); + co_return s1->requests() > 0 && s2->requests() > 0; + })); + }, + cfg) + .finally(seastar::coroutine::lambda([&s1, &s2] -> future<> { + co_await s1->stop(); + co_await s2->stop(); + })); +} diff --git a/vector_search/load_balancer.hh b/vector_search/load_balancer.hh new file mode 100644 index 0000000000..db02a73140 --- /dev/null +++ b/vector_search/load_balancer.hh @@ -0,0 +1,50 @@ +/* + * Copyright (C) 2025-present ScyllaDB + */ + +/* + * SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 + */ + +#pragma once + +#include "seastar/core/shared_ptr.hh" +#include + +namespace vector_search { + +/* Hands out the elements of the container it was constructed with, one per next() call. Each call picks a uniformly distributed index among the remaining elements and removes that element, so no element is returned twice. */ template +class load_balancer { +public: + load_balancer(std::vector> container, RandomNumberEngine& g) + : _container(std::move(container)) + , _g(g) { + } + + /* Returns one of the remaining elements chosen at random, or nullptr once the container is empty. */ seastar::lw_shared_ptr next() { + if (_container.empty()) { + return nullptr; + } + return pop(randomize_index()); + } + +private: + using distribution = std::uniform_int_distribution; + + /* Precondition: _container is not empty, otherwise size() - 1 wraps around; next() checks emptiness before calling. 
*/ size_t randomize_index() { + return _dist(_g, distribution::param_type(0, _container.size() - 1)); + } + + seastar::lw_shared_ptr pop(size_t index) { + auto ret = _container[index]; + /* move the chosen slot to the back and drop it; the relative order of the remaining elements is not preserved */ std::swap(_container[index], _container.back()); + _container.pop_back(); + return ret; + } + + std::vector> _container; + /* held by reference: the engine must outlive this object */ RandomNumberEngine& _g; + distribution _dist; +}; + +} // namespace vector_search diff --git a/vector_search/vector_store_client.cc b/vector_search/vector_store_client.cc index 130671d75c..6e19277377 100644 --- a/vector_search/vector_store_client.cc +++ b/vector_search/vector_store_client.cc @@ -8,6 +8,7 @@ #include "vector_store_client.hh" #include "dns.hh" +#include "load_balancer.hh" #include "cql3/statements/select_statement.hh" #include "cql3/type_json.hh" #include "db/config.hh" @@ -21,6 +22,7 @@ #include #include #include +#include #include #include 
#include @@ -62,6 +64,8 @@ constexpr auto ANN_RETRIES = 3; // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) logging::logger vslogger("vector_store_client"); +/* per-thread engine seeded from std::random_device; handed by reference to the load_balancer in the request loop further down */ static thread_local auto random_engine = std::default_random_engine(std::random_device{}()); + auto parse_port(std::string const& port_txt) -> std::optional { auto port = port_number{}; auto [ptr, ec] = std::from_chars(&*port_txt.begin(), &*port_txt.end(), port); @@ -439,7 +443,8 @@ struct vector_store_client::impl { clients.error())}; } - for (const auto& client : *clients) { + /* visit the clients in random order, each at most once; *clients is moved from here and must not be used afterwards */ load_balancer lb(std::move(*clients), random_engine); + while (auto client = lb.next()) { auto result = co_await coroutine::as_future(client->make_request( method, path, content, [&resp](http::reply const& reply, input_stream body) -> future<> {