diff --git a/test/vector_search/utils.hh b/test/vector_search/utils.hh index 7ba8be8d0b..4839f26344 100644 --- a/test/vector_search/utils.hh +++ b/test/vector_search/utils.hh @@ -155,11 +155,6 @@ struct unreachable_socket { conn.shutdown_output(); co_await conn.wait_input_shutdown(); } - // There is currently no effective way to abort an ongoing connect in Seastar. - // Timing out connect by with_timeout, remains pending coroutine in the reactor. - // To prevent resource leaks, we close the unreachable socket and sleep, - // allowing the pending connect coroutines to fail and release their resources. - co_await seastar::sleep(3s); } }; diff --git a/vector_search/client.cc b/vector_search/client.cc index 7e44aa044b..5228637b6c 100644 --- a/vector_search/client.cc +++ b/vector_search/client.cc @@ -8,6 +8,7 @@ #include "client.hh" #include "utils.hh" +#include "utils/composite_abort_source.hh" #include "utils/exceptions.hh" #include "utils/exponential_backoff_retry.hh" #include "utils/rjson.hh" @@ -18,6 +19,8 @@ #include #include #include +#include +#include #include #include #include @@ -33,6 +36,35 @@ bool is_ip_address(const sstring& host) { return net::inet_address::parse_numerical(host).has_value(); } +future connect_with_as(socket_address addr, shared_ptr creds, sstring host, abort_source& as) { + as.check(); + auto sock = make_socket(); + auto sub = as.subscribe([&sock]() noexcept { + sock.shutdown(); + }); + auto f = co_await coroutine::as_future(sock.connect(addr)); + if (as.abort_requested()) { + f.ignore_ready_future(); + throw abort_requested_exception(); + } + + auto cs = co_await coroutine::try_future(std::move(f)); + if (creds) { + tls::tls_options opts; + if (!is_ip_address(host)) { + opts.server_name = host; + } + auto tls_cs = co_await tls::wrap_client(creds, std::move(cs), std::move(opts)); + co_return tls_cs; + } + co_return cs; +} + + +bool is_request_aborted(std::exception_ptr& err) { + return try_catch(err) != nullptr; +} + class client_connection_factory : public http::experimental::connection_factory { client::endpoint_type _endpoint; shared_ptr _creds; @@ -46,31 +78,35 @@ public: } future make([[maybe_unused]] abort_source* as) override { - auto deadline = std::chrono::steady_clock::now() + timeout(); - auto socket = co_await with_timeout(deadline, connect()); + auto t = timeout(); + auto socket = co_await connect(t, as); socket.set_nodelay(true); - socket.set_keepalive_parameters(get_keepalive_parameters(timeout())); + socket.set_keepalive_parameters(get_keepalive_parameters(t)); socket.set_keepalive(true); - unsigned int timeout_ms = timeout().count(); + unsigned int timeout_ms = t.count(); socket.set_sockopt(IPPROTO_TCP, TCP_USER_TIMEOUT, &timeout_ms, sizeof(timeout_ms)); co_return socket; } private: - future connect() { - auto addr = socket_address(_endpoint.ip, _endpoint.port); - if (_creds) { - tls::tls_options opts; - if (!is_ip_address(_endpoint.host)) { - opts.server_name = _endpoint.host; - } - auto socket = co_await tls::connect(_creds, addr, std::move(opts)); - // tls::connect() only performs the TCP handshake — the TLS handshake is deferred until the first I/O operation. - // Force the TLS handshake to happen here so that the connection timeout applies to it. - co_await tls::check_session_is_resumed(socket); - co_return socket; + future connect(std::chrono::milliseconds timeout, abort_source* as) { + abort_on_expiry timeout_as(seastar::lowres_clock::now() + timeout); + utils::composite_abort_source composite_as; + composite_as.add(timeout_as.abort_source()); + if (as) { + composite_as.add(*as); } - co_return co_await seastar::connect(addr, {}, transport::TCP); + auto f = co_await coroutine::as_future( + connect_with_as(socket_address(_endpoint.ip, _endpoint.port), _creds, _endpoint.host, composite_as.abort_source())); + if (f.failed()) { + auto err = f.get_exception(); + // When the connection abort was triggered by our own deadline rethrow as timed_out_error. + if (is_request_aborted(err) && timeout_as.abort_source().abort_requested()) { + co_await coroutine::return_exception(timed_out_error{}); + } + co_await coroutine::return_exception_ptr(std::move(err)); + } + co_return co_await std::move(f); } std::chrono::milliseconds timeout() const { @@ -93,10 +129,6 @@ bool is_server_problem(std::exception_ptr& err) { return is_server_unavailable(err) || try_catch(err) != nullptr || try_catch(err) != nullptr; } -bool is_request_aborted(std::exception_ptr& err) { - return try_catch(err) != nullptr; -} - future map_err(std::exception_ptr& err) { if (is_server_problem(err)) { co_return service_unavailable_error{};