diff --git a/test/vector_search/utils.hh b/test/vector_search/utils.hh index 4c41eb2ae8..fcdcb3ab40 100644 --- a/test/vector_search/utils.hh +++ b/test/vector_search/utils.hh @@ -154,11 +154,6 @@ struct unreachable_socket { conn.shutdown_output(); co_await conn.wait_input_shutdown(); } - // There is currently no effective way to abort an ongoing connect in Seastar. - // Timing out connect by with_timeout, remains pending coroutine in the reactor. - // To prevent resource leaks, we close the unreachable socket and sleep, - // allowing the pending connect coroutines to fail and release their resources. - co_await seastar::sleep(3s); } }; diff --git a/test/vector_search/vector_store_client_test.cc b/test/vector_search/vector_store_client_test.cc index b0c5a6f3f3..7c8cf8a5a9 100644 --- a/test/vector_search/vector_store_client_test.cc +++ b/test/vector_search/vector_store_client_test.cc @@ -920,18 +920,20 @@ SEASTAR_TEST_CASE(vector_store_client_updates_backoff_max_time_from_read_request // Verify backoff timing between status check connections. // Skip the first connection (ANN request) and analyze status check intervals. + // Allow small tolerance for timer imprecision: measured intervals can be slightly shorter than the programmed sleep duration. + constexpr auto TIMER_TOLERANCE = std::chrono::milliseconds(10); auto duration_between_1st_and_2nd_status_check = std::chrono::duration_cast( unavail_s->connections().at(2).timestamp - unavail_s->connections().at(1).timestamp); - BOOST_CHECK_GE(duration_between_1st_and_2nd_status_check, std::chrono::milliseconds(100)); + BOOST_CHECK_GE(duration_between_1st_and_2nd_status_check, std::chrono::milliseconds(100) - TIMER_TOLERANCE); BOOST_CHECK_LT(duration_between_1st_and_2nd_status_check, std::chrono::milliseconds(200)); auto duration_between_2nd_and_3rd_status_check = std::chrono::duration_cast( unavail_s->connections().at(3).timestamp - unavail_s->connections().at(2).timestamp); // Max backoff time reached at 200ms, so subsequent status checks use fixed 200ms intervals. - BOOST_CHECK_GE(duration_between_2nd_and_3rd_status_check, std::chrono::milliseconds(200)); // 200ms = 100ms * 2 + BOOST_CHECK_GE(duration_between_2nd_and_3rd_status_check, std::chrono::milliseconds(200) - TIMER_TOLERANCE); // 200ms = 100ms * 2 BOOST_CHECK_LT(duration_between_2nd_and_3rd_status_check, std::chrono::milliseconds(400)); auto duration_between_3rd_and_4th_status_check = std::chrono::duration_cast( unavail_s->connections().at(4).timestamp - unavail_s->connections().at(3).timestamp); - BOOST_CHECK_GE(duration_between_3rd_and_4th_status_check, std::chrono::milliseconds(200)); + BOOST_CHECK_GE(duration_between_3rd_and_4th_status_check, std::chrono::milliseconds(200) - TIMER_TOLERANCE); BOOST_CHECK_LT(duration_between_3rd_and_4th_status_check, std::chrono::milliseconds(400)); }, cfg) diff --git a/vector_search/client.cc b/vector_search/client.cc index 1bde6327e8..2c23d1f08b 100644 --- a/vector_search/client.cc +++ b/vector_search/client.cc @@ -8,6 +8,7 @@ #include "client.hh" #include "utils.hh" +#include "utils/composite_abort_source.hh" #include "utils/exceptions.hh" #include "utils/exponential_backoff_retry.hh" #include "utils/rjson.hh" @@ -18,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -28,6 +30,39 @@ using namespace std::chrono_literals; namespace vector_search { namespace { +bool is_ip_address(const sstring& host) { + return net::inet_address::parse_numerical(host).has_value(); +} + +future connect_with_as(socket_address addr, shared_ptr creds, sstring host, abort_source& as) { + as.check(); + auto sock = make_socket(); + auto sub = as.subscribe([&sock]() noexcept { + sock.shutdown(); + }); + auto f = co_await coroutine::as_future(sock.connect(addr)); + if (as.abort_requested()) { + f.ignore_ready_future(); + throw abort_requested_exception(); + } + + auto cs = co_await std::move(f); + if (creds) { + tls::tls_options opts; + if (!is_ip_address(host)) { + opts.server_name = host; + } + auto tls_cs = co_await tls::wrap_client(creds, std::move(cs), std::move(opts)); + co_return tls_cs; + } + co_return cs; +} + + +bool is_request_aborted(std::exception_ptr& err) { + return try_catch(err) != nullptr; +} + class client_connection_factory : public http::experimental::connection_factory { client::endpoint_type _endpoint; shared_ptr _creds; @@ -41,27 +76,35 @@ public: } future make([[maybe_unused]] abort_source* as) override { - auto deadline = std::chrono::steady_clock::now() + timeout(); - auto socket = co_await with_timeout(deadline, connect()); + auto t = timeout(); + auto socket = co_await connect(t, as); socket.set_nodelay(true); - socket.set_keepalive_parameters(get_keepalive_parameters(timeout())); + socket.set_keepalive_parameters(get_keepalive_parameters(t)); socket.set_keepalive(true); - unsigned int timeout_ms = timeout().count(); + unsigned int timeout_ms = t.count(); socket.set_sockopt(IPPROTO_TCP, TCP_USER_TIMEOUT, &timeout_ms, sizeof(timeout_ms)); co_return socket; } private: - future connect() { - auto addr = socket_address(_endpoint.ip, _endpoint.port); - if (_creds) { - auto socket = co_await tls::connect(_creds, addr, tls::tls_options{.server_name = _endpoint.host}); - // tls::connect() only performs the TCP handshake — the TLS handshake is deferred until the first I/O operation. - // Force the TLS handshake to happen here so that the connection timeout applies to it. - co_await tls::check_session_is_resumed(socket); - co_return socket; + future connect(std::chrono::milliseconds timeout, abort_source* as) { + abort_on_expiry timeout_as(seastar::lowres_clock::now() + timeout); + utils::composite_abort_source composite_as; + composite_as.add(timeout_as.abort_source()); + if (as) { + composite_as.add(*as); } - co_return co_await seastar::connect(addr, {}, transport::TCP); + auto f = co_await coroutine::as_future( + connect_with_as(socket_address(_endpoint.ip, _endpoint.port), _creds, _endpoint.host, composite_as.abort_source())); + if (f.failed()) { + auto err = f.get_exception(); + // When the connection abort was triggered by our own deadline rethrow as timed_out_error. + if (is_request_aborted(err) && timeout_as.abort_source().abort_requested()) { + co_await coroutine::return_exception(timed_out_error{}); + } + co_await coroutine::return_exception_ptr(std::move(err)); + } + co_return co_await std::move(f); } std::chrono::milliseconds timeout() const { @@ -84,10 +127,6 @@ bool is_server_problem(std::exception_ptr& err) { return is_server_unavailable(err) || try_catch(err) != nullptr || try_catch(err) != nullptr; } -bool is_request_aborted(std::exception_ptr& err) { - return try_catch(err) != nullptr; -} - future map_err(std::exception_ptr& err) { if (is_server_problem(err)) { co_return service_unavailable_error{};