connection_factory: introduce TTL timer

Add a TTL-based timer to connection_factory to automatically refresh
resolved host name addresses when they expire.
This commit is contained in:
Ernest Zaslavsky
2026-01-26 14:37:25 +02:00
parent 66a33619da
commit 6eb7dba352
2 changed files with 42 additions and 21 deletions

View File

@@ -23,12 +23,6 @@ future<shared_ptr<tls::certificate_credentials>> utils::http::system_trust_crede
co_return system_trust_credentials;
}
future<> utils::http::dns_connection_factory::init_addresses() {
auto hent = co_await net::dns::get_host_by_name(_host, net::inet_address::family::INET);
_addr_list = std::move(hent.addr_list);
_logger.debug("Initialized addresses={}", _addr_list);
}
future<> utils::http::dns_connection_factory::init_credentials() {
if (_use_https && !_creds) {
_creds = co_await system_trust_credentials();
@@ -40,14 +34,28 @@ future<> utils::http::dns_connection_factory::init_credentials() {
}
future<net::inet_address> utils::http::dns_connection_factory::get_address() {
if (!_addr_init) [[unlikely]] {
auto units = co_await get_units(_init_semaphore, 1);
if (!_addr_init) {
co_await init_addresses();
_addr_init = true;
}
auto get_addr = [this] -> net::inet_address {
const auto& addresses = _addr_list.value();
return addresses[_addr_pos++ % addresses.size()];
};
if (_addr_list) {
co_return get_addr();
}
co_return _addr_list[_addr_pos++ % _addr_list.size()];
auto units = co_await get_units(_init_semaphore, 1);
if (!_addr_list) {
auto hent = co_await net::dns::get_host_by_name(_host, net::inet_address::family::INET);
_address_ttl = std::ranges::min_element(hent.addr_entries, [](const net::hostent::address_entry& lhs, const net::hostent::address_entry& rhs) {
return lhs.ttl < rhs.ttl;
})->ttl;
if (_address_ttl.count() == 0) {
co_return hent.addr_entries[_addr_pos++ % hent.addr_entries.size()].addr;
}
_addr_list = hent.addr_entries | std::views::transform(&net::hostent::address_entry::addr) | std::ranges::to<std::vector>();
_addr_update_timer.rearm(lowres_clock::now() + _address_ttl);
}
co_return get_addr();
}
future<shared_ptr<tls::certificate_credentials>> utils::http::dns_connection_factory::get_creds() {
@@ -61,8 +69,8 @@ future<shared_ptr<tls::certificate_credentials>> utils::http::dns_connection_fac
co_return _creds;
}
future<connected_socket> utils::http::dns_connection_factory::connect() {
auto socket_addr = socket_address(co_await get_address(), _port);
future<connected_socket> utils::http::dns_connection_factory::connect(net::inet_address address) {
auto socket_addr = socket_address(address, _port);
if (auto creds = co_await get_creds()) {
_logger.debug("Making new HTTPS connection addr={} host={}", socket_addr, _host);
co_return co_await tls::connect(creds, socket_addr, tls::tls_options{.server_name = _host});
@@ -78,7 +86,13 @@ utils::http::dns_connection_factory::dns_connection_factory(std::string host, in
, _port(port)
, _logger(logger)
,_creds(std::move(certs))
,_use_https(use_https) {
, _use_https(use_https)
, _addr_update_timer([this] {
if (auto units = try_get_units(_init_semaphore, 1)) {
_addr_list.reset();
}
}) {
_addr_update_timer.arm(lowres_clock::now());
}
utils::http::dns_connection_factory::dns_connection_factory(std::string uri, logging::logger& logger, shared_ptr<tls::certificate_credentials> certs)
@@ -92,7 +106,13 @@ utils::http::dns_connection_factory::dns_connection_factory(std::string uri, log
{}
future<connected_socket> utils::http::dns_connection_factory::make(abort_source*) {
co_return co_await connect();
auto address = co_await get_address();
co_return co_await connect(address);
}
future<> utils::http::dns_connection_factory::close() {
_addr_update_timer.cancel();
co_await get_units(_init_semaphore, 1);
}
static const char HTTPS[] = "https";

View File

@@ -27,24 +27,25 @@ protected:
int _port;
logging::logger& _logger;
semaphore _init_semaphore{1};
bool _addr_init = false;
bool _creds_init = false;
std::vector<net::inet_address> _addr_list;
std::optional<std::vector<net::inet_address>> _addr_list;
shared_ptr<tls::certificate_credentials> _creds;
uint16_t _addr_pos{0};
bool _use_https;
std::chrono::seconds _address_ttl{0};
timer<lowres_clock> _addr_update_timer;
future<> init_addresses();
future<> init_credentials();
future<net::inet_address> get_address();
future<shared_ptr<tls::certificate_credentials>> get_creds();
future<connected_socket> connect();
future<connected_socket> connect(net::inet_address address);
public:
dns_connection_factory(dns_connection_factory&&);
dns_connection_factory(std::string host, int port, bool use_https, logging::logger& logger, shared_ptr<tls::certificate_credentials> = {});
dns_connection_factory(std::string endpoint_url, logging::logger& logger, shared_ptr<tls::certificate_credentials> = {});
virtual future<connected_socket> make(abort_source*) override;
future<> close() override;
};
// simple URL parser, just enough to handle required aspects for normal endpoint usage