diff --git a/service/vector_store_client.cc b/service/vector_store_client.cc index 9b1ceb9b09..0f013082ba 100644 --- a/service/vector_store_client.cc +++ b/service/vector_store_client.cc @@ -9,14 +9,35 @@ #include "vector_store_client.hh" #include "db/config.hh" #include "exceptions/exceptions.hh" +#include "utils/sequential_producer.hh" #include +#include #include +#include +#include +#include +#include +#include namespace { using configuration_exception = exceptions::configuration_exception; +using duration = lowres_clock::duration; using host_name = service::vector_store_client::host_name; +using http_client = http::experimental::client; +using inet_address = seastar::net::inet_address; +using milliseconds = std::chrono::milliseconds; using port_number = service::vector_store_client::port_number; +using time_point = lowres_clock::time_point; + +// Wait time before retrying after an exception occurred +constexpr auto EXCEPTION_OCCURED_WAIT = std::chrono::seconds(5); + +// Minimum interval between dns name refreshes +constexpr auto DNS_REFRESH_INTERVAL = std::chrono::seconds(5); + +/// Timeout for waiting for a new client to be available +constexpr auto WAIT_FOR_CLIENT_TIMEOUT = std::chrono::seconds(5); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) logging::logger vslogger("vector_store_client"); @@ -46,18 +67,203 @@ auto parse_service_uri(std::string_view uri) -> std::optional future { + auto result = co_await coroutine::as_future(sleep_abortable(timeout, as)); + if (result.failed()) { + auto err = result.get_exception(); + if (as.abort_requested()) { + co_return false; + } + co_await coroutine::return_exception_ptr(std::move(err)); + } + co_return true; +} + +/// Wait for a condition variable to be signaled or timeout. +auto wait_for_signal(condition_variable& cv, time_point timeout) -> future { + auto result = co_await coroutine::as_future(cv.wait(timeout)); + if (result.failed()) { + auto err = result.get_exception(); + if (try_catch(err) != nullptr) { + co_return false; + } + co_await coroutine::return_exception_ptr(std::move(err)); + } + co_return true; +} + } // namespace namespace service { struct vector_store_client::impl { + lw_shared_ptr current_client; + std::vector> old_clients; host_name host; port_number port{}; + inet_address addr; + time_point last_dns_refresh; gate tasks_gate; + condition_variable refresh_cv; + condition_variable refresh_client_cv; + abort_source abort_refresh; + milliseconds dns_refresh_interval = DNS_REFRESH_INTERVAL; + milliseconds wait_for_client_timeout = WAIT_FOR_CLIENT_TIMEOUT; + std::function>(sstring const&)> dns_resolver; + sequential_producer> client_producer; impl(host_name host_, port_number port_) : host(std::move(host_)) - , port(port_) { + , port(port_) + , dns_resolver([](auto const& host) -> future> { + auto addr = co_await coroutine::as_future(net::dns::resolve_name(host)); + if (addr.failed()) { + auto err = addr.get_exception(); + if (try_catch(err) != nullptr) { + co_return std::nullopt; + } + co_await coroutine::return_exception_ptr(std::move(err)); + } + co_return co_await std::move(addr); + }) + , client_producer([&]() -> future> { + trigger_dns_refresh(); + co_await wait_for_signal(refresh_client_cv, lowres_clock::now() + wait_for_client_timeout); + co_return current_client; + }) { + } + + /// Refresh the http client with a new address resolved from the DNS name. + /// If the DNS resolution fails, the current client is set to nullptr. + /// If the address is the same as the current one, do nothing. + /// Old clients are saved for later cleanup in a specific task. + auto refresh_addr() -> future<> { + auto new_addr = co_await dns_resolver(host); + if (!new_addr) { + current_client = nullptr; + co_return; + } + + // Check if the new address is the same as the current one + if (current_client && *new_addr == addr) { + co_return; + } + + addr = *new_addr; + old_clients.emplace_back(current_client); + current_client = make_lw_shared(socket_address(addr, port)); + } + + /// A task for refreshing the vector store http client. + auto refresh_addr_task() -> future<> { + for (;;) { + auto exception_occured = false; + try { + if (abort_refresh.abort_requested()) { + break; + } + + // Do not refresh the service address too often + auto now = lowres_clock::now(); + auto current_duration = now - last_dns_refresh; + if (current_duration > dns_refresh_interval) { + last_dns_refresh = now; + co_await refresh_addr(); + } else { + // Wait till the end of the refreshing interval + if (co_await wait_for_timeout(dns_refresh_interval - current_duration, abort_refresh)) { + continue; + } + // If the wait was aborted, we stop refreshing + break; + } + + if (abort_refresh.abort_requested()) { + break; + } + + // new client is available + refresh_client_cv.broadcast(); + + co_await cleanup_old_clients(); + + co_await refresh_cv.when(); + } catch (const std::exception& e) { + vslogger.error("Vector Store Client refresh task failed: {}", e.what()); + exception_occured = true; + } catch (...) { + vslogger.error("Vector Store Client refresh task failed with unknown exception"); + exception_occured = true; + } + if (exception_occured) { + // If an exception occurred, we wait for the next signal to refresh the address + co_await wait_for_timeout(EXCEPTION_OCCURED_WAIT, abort_refresh); + } + } + + co_await cleanup_old_clients(); + co_await cleanup_current_client(); + } + + /// Request a DNS refresh in the specific task. + void trigger_dns_refresh() { + refresh_cv.signal(); + } + + /// Cleanup current client + auto cleanup_current_client() -> future<> { + if (current_client) { + co_await current_client->close(); + } + current_client = nullptr; + } + + /// Cleanup old clients that are no longer used. + auto cleanup_old_clients() -> future<> { + // iterate over old clients and close them. There is a co_await in the loop + // so we need to use [] accessor and copying clients to avoid dangling references of iterators. + // NOLINTNEXTLINE(modernize-loop-convert) + for (auto it = 0U; it < old_clients.size(); ++it) { + auto& client = old_clients[it]; + if (client && client.owned()) { + auto client_cloned = client; + co_await client_cloned->close(); + client_cloned = nullptr; + } + } + std::erase_if(old_clients, [](auto const& client) { + return !client; + }); + } + + struct get_client_response { + lw_shared_ptr client; ///< The http client. + host_name host; ///< The host name for the vector-store service. + }; + + using get_client_error = std::variant; + + /// Get the current http client or wait for a new one to be available. + auto get_client(abort_source& as) -> future> { + if (current_client) { + co_return get_client_response{.client = current_client, .host = host}; + } + + auto current_client = co_await coroutine::as_future(client_producer(as)); + + if (current_client.failed()) { + auto err = current_client.get_exception(); + if (as.abort_requested()) { + co_return std::unexpected{aborted{}}; + } + co_await coroutine::return_exception_ptr(std::move(err)); + } + auto client = co_await std::move(current_client); + if (!client) { + co_return std::unexpected{addr_unavailable{}}; + } + co_return get_client_response{.client = client, .host = host}; } }; @@ -84,12 +290,22 @@ void vector_store_client::start_background_tasks() { if (is_disabled()) { return; } + + /// start the background task to refresh the service address + (void)try_with_gate(_impl->tasks_gate, [this] { + return _impl->refresh_addr_task(); + }).handle_exception([](std::exception_ptr eptr) { + on_internal_error_noexcept(vslogger, format("The Vector Store Client refresh task failed: {}", eptr)); + }); } auto vector_store_client::stop() -> future<> { if (is_disabled()) { co_return; } + + _impl->abort_refresh.request_abort(); + _impl->refresh_cv.signal(); co_await _impl->tasks_gate.close(); } @@ -107,5 +323,44 @@ auto vector_store_client::port() const -> std::expected { return {_impl->port}; } +void vector_store_client_tester::set_dns_refresh_interval(vector_store_client& vsc, std::chrono::milliseconds interval) { + if (vsc.is_disabled()) { + on_internal_error(vslogger, "Cannot set dns_refresh_interval on a disabled vector store client"); + } + vsc._impl->dns_refresh_interval = interval; +} + +void vector_store_client_tester::set_wait_for_client_timeout(vector_store_client& vsc, std::chrono::milliseconds timeout) { + if (vsc.is_disabled()) { + on_internal_error(vslogger, "Cannot set wait_for_client_timeout on a disabled vector store client"); + } + vsc._impl->wait_for_client_timeout = timeout; +} + +void vector_store_client_tester::set_dns_resolver(vector_store_client& vsc, std::function>(sstring const&)> resolver) { + if (vsc.is_disabled()) { + on_internal_error(vslogger, "Cannot set dns_resolver on a disabled vector store client"); + } + vsc._impl->dns_resolver = std::move(resolver); +} + +void vector_store_client_tester::trigger_dns_resolver(vector_store_client& vsc) { + if (vsc.is_disabled()) { + on_internal_error(vslogger, "Cannot trigger a dns resolver on a disabled vector store client"); + } + vsc._impl->trigger_dns_refresh(); +} + +auto vector_store_client_tester::resolve_hostname(vector_store_client& vsc, abort_source& as) -> future> { + if (vsc.is_disabled()) { + on_internal_error(vslogger, "Cannot check hostname resolving on a disabled vector store client"); + } + auto client_host = co_await vsc._impl->get_client(as); + if (!client_host) { + co_return std::nullopt; + } + co_return vsc._impl->addr; +} + } // namespace service diff --git a/service/vector_store_client.hh b/service/vector_store_client.hh index 00fcae2a1e..f9aa5cdb09 100644 --- a/service/vector_store_client.hh +++ b/service/vector_store_client.hh @@ -17,6 +17,10 @@ namespace db { class config; } +namespace seastar::net { +class inet_address; +} + namespace service { /// A client with the vector-store service. @@ -32,6 +36,12 @@ public: /// The vector_store_client service is disabled. struct disabled {}; + /// The operation was aborted. + struct aborted {}; + + /// The vector-store addr is unavailable (not possible to get an addr from the dns service). + struct addr_unavailable {}; + explicit vector_store_client(config const& cfg); ~vector_store_client(); @@ -57,7 +67,13 @@ private: }; /// A tester for the vector_store_client, used for testing purposes. -struct vector_store_client_tester {}; +struct vector_store_client_tester { + static void set_dns_refresh_interval(vector_store_client& vsc, std::chrono::milliseconds interval); + static void set_wait_for_client_timeout(vector_store_client& vsc, std::chrono::milliseconds timeout); + static void set_dns_resolver(vector_store_client& vsc, std::function>(sstring const&)> resolver); + static void trigger_dns_resolver(vector_store_client& vsc); + static auto resolve_hostname(vector_store_client& vsc, abort_source& as) -> future>; +}; } // namespace service diff --git a/test/boost/vector_store_client_test.cc b/test/boost/vector_store_client_test.cc index abb3355485..58f64803ba 100644 --- a/test/boost/vector_store_client_test.cc +++ b/test/boost/vector_store_client_test.cc @@ -26,6 +26,23 @@ using vector_store_client = service::vector_store_client; using vector_store_client_tester = service::vector_store_client_tester; using config = vector_store_client::config; using configuration_exception = exceptions::configuration_exception; +using inet_address = seastar::net::inet_address; +using milliseconds = std::chrono::milliseconds; +using port_number = vector_store_client::port_number; + +auto repeat_until(milliseconds timeout, std::function()> func) -> future { + auto begin = lowres_clock::now(); + while (!co_await func()) { + if (lowres_clock::now() - begin > timeout) { + co_return false; + } + } + co_return true; +} + +auto print_addr(const inet_address& addr) -> sstring { + return format("{}", addr); +} } // namespace @@ -62,3 +79,168 @@ BOOST_AUTO_TEST_CASE(vector_store_client_test_ctor) { } } +/// Resolving of the hostname is started in start_background_tasks() +SEASTAR_TEST_CASE(vector_store_client_test_dns_started) { + auto cfg = config(); + cfg.vector_store_uri.set("http://good.authority.here:6080"); + + auto vs = vector_store_client{cfg}; + BOOST_CHECK(!vs.is_disabled()); + + vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(2000)); + vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(100)); + vector_store_client_tester::set_dns_resolver(vs, [](auto const& host) -> future> { + BOOST_CHECK_EQUAL(host, "good.authority.here"); + co_return inet_address("127.0.0.1"); + }); + + vs.start_background_tasks(); + + auto as = abort_source(); + auto addr = co_await vector_store_client_tester::resolve_hostname(vs, as); + BOOST_REQUIRE(addr); + BOOST_CHECK_EQUAL(print_addr(*addr), "127.0.0.1"); + + co_await vs.stop(); +} + +/// Unable to resolve the hostname +SEASTAR_TEST_CASE(vector_store_client_test_dns_resolve_failure) { + auto cfg = config(); + cfg.vector_store_uri.set("http://good.authority.here:6080"); + + + auto vs = vector_store_client{cfg}; + BOOST_CHECK(!vs.is_disabled()); + + vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(2000)); + vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(100)); + vector_store_client_tester::set_dns_resolver(vs, [](auto const& host) -> future> { + BOOST_CHECK_EQUAL(host, "good.authority.here"); + co_return std::nullopt; + }); + + vs.start_background_tasks(); + + auto as = abort_source(); + BOOST_CHECK(!co_await vector_store_client_tester::resolve_hostname(vs, as)); + + co_await vs.stop(); +} + +/// Resolving of the hostname is repeated after errors +SEASTAR_TEST_CASE(vector_store_client_test_dns_resolving_repeated) { + auto cfg = config(); + cfg.vector_store_uri.set("http://good.authority.here:6080"); + auto vs = vector_store_client{cfg}; + BOOST_CHECK(!vs.is_disabled()); + + vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(10)); + vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(20)); + auto count = 0; + vector_store_client_tester::set_dns_resolver(vs, [&count](auto const& host) -> future> { + BOOST_CHECK_EQUAL(host, "good.authority.here"); + count++; + if (count % 3 != 0) { + co_return std::nullopt; + } + co_return inet_address(format("127.0.0.{}", count)); + }); + + vs.start_background_tasks(); + + auto as = abort_source(); + BOOST_CHECK(co_await repeat_until(std::chrono::milliseconds(1000), [&vs, &as]() -> future { + co_return co_await vector_store_client_tester::resolve_hostname(vs, as); + })); + BOOST_CHECK_EQUAL(count, 3); + auto addr = co_await vector_store_client_tester::resolve_hostname(vs, as); + BOOST_REQUIRE(addr); + BOOST_CHECK_EQUAL(print_addr(*addr), "127.0.0.3"); + + vector_store_client_tester::trigger_dns_resolver(vs); + + BOOST_CHECK(co_await repeat_until(std::chrono::milliseconds(1000), [&vs, &as]() -> future { + co_return !co_await vector_store_client_tester::resolve_hostname(vs, as); + })); + + BOOST_CHECK(co_await repeat_until(std::chrono::milliseconds(1000), [&vs, &as]() -> future { + co_return co_await vector_store_client_tester::resolve_hostname(vs, as); + })); + BOOST_CHECK_EQUAL(count, 6); + addr = co_await vector_store_client_tester::resolve_hostname(vs, as); + BOOST_REQUIRE(addr); + BOOST_CHECK_EQUAL(print_addr(*addr), "127.0.0.6"); + + co_await vs.stop(); +} + +/// Minimal interval between DNS refreshes is respected +SEASTAR_TEST_CASE(vector_store_client_test_dns_refresh_respects_interval) { + auto cfg = config(); + cfg.vector_store_uri.set("http://good.authority.here:6080"); + auto vs = vector_store_client{cfg}; + BOOST_CHECK(!vs.is_disabled()); + + vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(10)); + vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(100)); + auto count = 0; + vector_store_client_tester::set_dns_resolver(vs, [&count](auto const& host) -> future> { + BOOST_CHECK_EQUAL(host, "good.authority.here"); + count++; + co_return inet_address("127.0.0.1"); + }); + + vs.start_background_tasks(); + co_await sleep(std::chrono::milliseconds(20)); // wait for the first DNS refresh + + auto as = abort_source(); + auto addr = co_await vector_store_client_tester::resolve_hostname(vs, as); + BOOST_REQUIRE(addr); + BOOST_CHECK_EQUAL(print_addr(*addr), "127.0.0.1"); + BOOST_CHECK_EQUAL(count, 1); + count = 0; + vector_store_client_tester::trigger_dns_resolver(vs); + vector_store_client_tester::trigger_dns_resolver(vs); + vector_store_client_tester::trigger_dns_resolver(vs); + vector_store_client_tester::trigger_dns_resolver(vs); + vector_store_client_tester::trigger_dns_resolver(vs); + co_await sleep(std::chrono::milliseconds(100)); // wait for the next DNS refresh + + addr = co_await vector_store_client_tester::resolve_hostname(vs, as); + BOOST_REQUIRE(addr); + BOOST_CHECK_EQUAL(print_addr(*addr), "127.0.0.1"); + BOOST_CHECK_GE(count, 1); + BOOST_CHECK_LE(count, 2); + + co_await vs.stop(); +} + +/// DNS refresh could be aborted +SEASTAR_TEST_CASE(vector_store_client_test_dns_refresh_aborted) { + auto cfg = config(); + cfg.vector_store_uri.set("http://good.authority.here:6080"); + auto vs = vector_store_client{cfg}; + BOOST_CHECK(!vs.is_disabled()); + + vector_store_client_tester::set_dns_refresh_interval(vs, std::chrono::milliseconds(10)); + vector_store_client_tester::set_wait_for_client_timeout(vs, std::chrono::milliseconds(100)); + vector_store_client_tester::set_dns_resolver(vs, [](auto const& host) -> future> { + BOOST_CHECK_EQUAL(host, "good.authority.here"); + co_await sleep(std::chrono::milliseconds(100)); + co_return inet_address("127.0.0.1"); + }); + + vs.start_background_tasks(); + + auto as = abort_source(); + auto timeout = timer([&as]() { + as.request_abort(); + }); + timeout.arm(std::chrono::milliseconds(10)); + auto addr = co_await vector_store_client_tester::resolve_hostname(vs, as); + BOOST_CHECK(!addr); + + co_await vs.stop(); +} +