/*
 * Copyright (C) 2021-present ScyllaDB
 */

/*
 * SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
 */

#include "generic_server.hh"

// NOTE(review): the original include list was garbled (angle-bracket contents
// stripped by extraction); reconstructed from what this translation unit
// actually uses (coroutines, when_all, parallel_for_each, tls, smp) — confirm
// against upstream.
#include <exception>

#include <seastar/core/coroutine.hh>
#include <seastar/core/loop.hh>
#include <seastar/core/reactor.hh>
#include <seastar/core/when_all.hh>
#include <seastar/coroutine/parallel_for_each.hh>
#include <seastar/net/tls.hh>

namespace generic_server {

// Data-source wrapper that returns the connection's CPU-concurrency semaphore
// units for the duration of each read and re-adopts them afterwards, so that
// a connection blocked on network input does not hold a concurrency slot.
// Once the connection is fully established (_cpu_concurrency.stopped == true)
// the wrapper becomes a transparent pass-through.
class counted_data_source_impl : public data_source_impl {
    data_source _ds;
    connection::cpu_concurrency_t& _cpu_concurrency;

    // Run `fun` with the concurrency units temporarily returned to the
    // semaphore; re-acquire one unit when `fun`'s future resolves.
    template <typename F>
    future<temporary_buffer<char>> invoke_with_counting(F&& fun) {
        if (_cpu_concurrency.stopped) {
            return fun();
        }
        return futurize_invoke([this] {
            // Give the units back while we wait on I/O...
            _cpu_concurrency.units.return_all();
        }).then([fun = std::move(fun)] {
            return fun();
        }).finally([this] {
            // ...and take one again before resuming CPU-bound work.
            _cpu_concurrency.units.adopt(consume_units(_cpu_concurrency.semaphore, 1));
        });
    }

public:
    counted_data_source_impl(data_source ds, connection::cpu_concurrency_t& cpu_concurrency)
        : _ds(std::move(ds))
        , _cpu_concurrency(cpu_concurrency)
    {}

    virtual ~counted_data_source_impl() = default;

    virtual future<temporary_buffer<char>> get() override {
        return invoke_with_counting([this] { return _ds.get(); });
    }

    virtual future<temporary_buffer<char>> skip(uint64_t n) override {
        return invoke_with_counting([this, n] { return _ds.skip(n); });
    }

    virtual future<> close() override {
        return _ds.close();
    }
};

// Data-sink counterpart of counted_data_source_impl: writes temporarily
// release the connection's CPU-concurrency units (see above).
class counted_data_sink_impl : public data_sink_impl {
    data_sink _ds;
    connection::cpu_concurrency_t& _cpu_concurrency;

    // Same unit-juggling as counted_data_source_impl::invoke_with_counting,
    // but for void-result operations; `fun` may be a mutable lambda.
    template <typename F>
    future<> invoke_with_counting(F&& fun) {
        if (_cpu_concurrency.stopped) {
            return fun();
        }
        return futurize_invoke([this] {
            _cpu_concurrency.units.return_all();
        }).then([fun = std::move(fun)] () mutable {
            return fun();
        }).finally([this] {
            _cpu_concurrency.units.adopt(consume_units(_cpu_concurrency.semaphore, 1));
        });
    }

public:
    counted_data_sink_impl(data_sink ds, connection::cpu_concurrency_t& cpu_concurrency)
        : _ds(std::move(ds))
        , _cpu_concurrency(cpu_concurrency)
    {}

    virtual ~counted_data_sink_impl() = default;

    virtual temporary_buffer<char> allocate_buffer(size_t size) override {
        return _ds.allocate_buffer(size);
    }

    virtual future<> put(net::packet data) override {
        return invoke_with_counting([this, data = std::move(data)] () mutable { return _ds.put(std::move(data)); });
    }

    virtual future<> put(std::vector<temporary_buffer<char>> data) override {
        return invoke_with_counting([this, data = std::move(data)] () mutable { return _ds.put(std::move(data)); });
    }

    virtual future<> put(temporary_buffer<char> buf) override {
        return invoke_with_counting([this, buf = std::move(buf)] () mutable { return _ds.put(std::move(buf)); });
    }

    virtual future<> flush() override {
        return invoke_with_counting([this] () mutable { return _ds.flush(); });
    }

    virtual future<> close() override {
        return _ds.close();
    }

    virtual size_t buffer_size() const noexcept override {
        return _ds.buffer_size();
    }

    virtual bool can_batch_flushes() const noexcept override {
        return _ds.can_batch_flushes();
    }

    virtual void on_batch_flush_error() noexcept override {
        _ds.on_batch_flush_error();
    }
};

connection::connection(server& server, connected_socket&& fd, named_semaphore& sem, semaphore_units<named_semaphore_exception_factory> initial_sem_units)
    : _conns_cpu_concurrency{sem, std::move(initial_sem_units), false}
    , _server{server}
    , _fd{std::move(fd)}
    // Wrap both directions of the socket so that I/O waits release the
    // CPU-concurrency units held by this connection.
    , _read_buf(data_source(std::make_unique<counted_data_source_impl>(_fd.input().detach(), _conns_cpu_concurrency)))
    , _write_buf(output_stream<char>(data_sink(std::make_unique<counted_data_sink_impl>(_fd.output().detach(), _conns_cpu_concurrency)), 8192, output_stream_options{.batch_flushes = true}))
    , _pending_requests_gate("generic_server::connection")
    , _hold_server(_server._gate)
{
    ++_server._total_connections;
    _server._connections_list.push_back(*this);
}

connection::~connection() {
    server::connections_list_t::iterator iter = _server._connections_list.iterator_to(*this);
    // Any in-flight gentle iteration pointing at this connection must be
    // advanced before we unlink ourselves, or it would dereference a dead node.
    for (auto&& gi : _server._gentle_iterators) {
        if (gi.iter == iter) {
            gi.iter++;
        }
    }
    _server._connections_list.erase(iter);
}

connection::execute_under_tenant_type connection::no_tenant() {
    // return a function that runs the process loop with no scheduling group games
    return [] (connection_process_loop loop) {
        return loop();
    };
}

void connection::switch_tenant(execute_under_tenant_type exec) {
    _execute_under_current_tenant = std::move(exec);
    _tenant_switch = true;
}

// Visit every connection with `fn`, yielding between elements; connection
// destruction cooperates via the _gentle_iterators list (see ~connection).
future<> server::for_each_gently(noncopyable_function<void(connection&)> fn) {
    _gentle_iterators.emplace_front(*this);
    std::list<gentle_iterator>::iterator gi = _gentle_iterators.begin();
    return seastar::do_until([ gi ] { return gi->iter == gi->end; },
        [ gi, fn = std::move(fn) ] {
            fn(*(gi->iter++));
            return make_ready_future<>();
        }
    ).finally([ this, gi ] { _gentle_iterators.erase(gi); });
}

// True if `ep` represents the peer closing the connection (EPIPE/ECONNRESET),
// including the TLS-wrapped flavors of the same condition.
static bool is_broken_pipe_or_connection_reset(std::exception_ptr ep) {
    try {
        std::rethrow_exception(ep);
    } catch (const std::system_error& e) {
        auto& code = e.code();
        if (code.category() == std::system_category()
                && (code.value() == EPIPE || code.value() == ECONNRESET)) {
            return true;
        }
        if (code.category() == tls::error_category()) {
            // Typically ECONNRESET
            if (code.value() == tls::ERROR_PREMATURE_TERMINATION) {
                return true;
            }
            // If we got an actual EPIPE in push/pull of gnutls, it is _not_ translated
            // to anything more useful than generic push/pull error. Need to look at
            // nested exception.
            if (code.value() == tls::ERROR_PULL || code.value() == tls::ERROR_PUSH) {
                if (auto p = dynamic_cast<const std::nested_exception*>(std::addressof(e))) {
                    return is_broken_pipe_or_connection_reset(p->nested_ptr());
                }
            }
        }
        return false;
    } catch (...) {
    }
    return false;
}

// Serve requests until EOF or until switch_tenant() asks us to re-enter the
// loop under a different tenant executor.
future<> connection::process_until_tenant_switch() {
    _tenant_switch = false;
    return do_until([this] {
        return _read_buf.eof() || _tenant_switch;
    }, [this] {
        return process_request();
    });
}

future<> connection::process() {
    return with_gate(_pending_requests_gate, [this] {
        return do_until([this] {
            return _read_buf.eof();
        }, [this] {
            return _execute_under_current_tenant([this] {
                return process_until_tenant_switch();
            });
        }).then_wrapped([this] (future<> f) {
            handle_error(std::move(f));
        });
    }).finally([this] {
        return _pending_requests_gate.close().then([this] {
            return _ready_to_respond.handle_exception([] (std::exception_ptr ep) {
                if (is_broken_pipe_or_connection_reset(ep)) {
                    // expected if another side closes a connection or we're shutting down
                    return;
                }
                std::rethrow_exception(ep);
            }).finally([this] {
                return _write_buf.close();
            });
        });
    });
}

void connection::on_connection_ready() {
    // The connection is established: stop counting it against the
    // accept-concurrency semaphore and release its units for good.
    _conns_cpu_concurrency.stopped = true;
    _conns_cpu_concurrency.units.return_all();
}

future<> connection::shutdown() {
    try {
        _fd.shutdown_input();
        _fd.shutdown_output();
    } catch (...) {
        // best effort — the socket may already be closed
    }
    return make_ready_future<>();
}

server::server(const sstring& server_name, logging::logger& logger, config cfg)
    : _server_name{server_name}
    , _logger{logger}
    , _gate("generic_server::server")
    , _conns_cpu_concurrency(cfg.uninitialized_connections_semaphore_cpu_concurrency)
    , _prev_conns_cpu_concurrency(_conns_cpu_concurrency)
    , _conns_cpu_concurrency_semaphore(_conns_cpu_concurrency, named_semaphore_exception_factory{"connections cpu concurrency semaphore"})
{
    // Track live-updatable concurrency: grow/shrink the semaphore by the
    // delta against the previously observed value.
    _conns_cpu_concurrency.observe([this] (const uint32_t& concurrency) {
        if (concurrency == _prev_conns_cpu_concurrency) {
            return;
        }
        if (concurrency > _prev_conns_cpu_concurrency) {
            _conns_cpu_concurrency_semaphore.signal(concurrency - _prev_conns_cpu_concurrency);
        } else {
            _conns_cpu_concurrency_semaphore.consume(_prev_conns_cpu_concurrency - concurrency);
        }
        _prev_conns_cpu_concurrency = concurrency;
    });
}

server::~server() {
}

future<> server::stop() {
    co_await shutdown();
    co_await std::exchange(_all_connections_stopped, make_ready_future<>());
}

future<> server::shutdown() {
    if (_gate.is_closed()) {
        co_return;
    }
    _all_connections_stopped = _gate.close();
    size_t nr = 0;
    size_t nr_total = _listeners.size();
    _logger.debug("abort accept nr_total={}", nr_total);
    for (auto&& l : _listeners) {
        l.abort_accept();
        _logger.debug("abort accept {} out of {} done", ++nr, nr_total);
    }
    size_t nr_conn = 0;
    auto nr_conn_total = _connections_list.size();
    _logger.debug("shutdown connection nr_total={}", nr_conn_total);
    co_await coroutine::parallel_for_each(_connections_list, [&] (auto&& c) -> future<> {
        co_await c.shutdown();
        _logger.debug("shutdown connection {} out of {} done", ++nr_conn, nr_conn_total);
    });
    co_await std::move(_listeners_stopped);
    _abort_source.request_abort();
}

// NOTE(review): the parameter/template types below were stripped from the
// original text; reconstructed to match the surviving call sites
// (builder->build_reloadable_server_credentials, get_shard_instance()) —
// confirm against the declaration in generic_server.hh.
future<> server::listen(socket_address addr, std::shared_ptr<seastar::tls::credentials_builder> builder, bool is_shard_aware, bool keepalive,
        std::optional<file_permissions> unix_domain_socket_permissions, std::function<server&()> get_shard_instance) {
    // Note: We are making the assumption that if builder is provided it will be the same for each
    // invocation, regardless of address etc. In general, only CQL server will call this multiple times,
    // and if TLS, it will use the same cert set.
    // Could hold certs in a map and ensure separation, but then we will for all
    // current uses of this class create duplicate reloadable certs for shard 0, which is
    // kind of what we wanted to avoid in the first place...
    if (builder && !_credentials) {
        if (!get_shard_instance || this_shard_id() == 0) {
            _credentials = co_await builder->build_reloadable_server_credentials([this, get_shard_instance = std::move(get_shard_instance)](const tls::credentials_builder& b, const std::unordered_set<sstring>& files, std::exception_ptr ep) -> future<> {
                if (ep) {
                    _logger.warn("Exception loading {}: {}", files, ep);
                } else {
                    if (get_shard_instance) {
                        // Shard 0 owns the reloadable credentials; propagate
                        // the rebuilt certs to every other shard's instance.
                        co_await smp::invoke_on_others([&]() {
                            auto& s = get_shard_instance();
                            if (s._credentials) {
                                b.rebuild(*s._credentials);
                            }
                        });
                    }
                    _logger.info("Reloaded {}", files);
                }
            });
        } else {
            _credentials = builder->build_server_credentials();
        }
    }
    listen_options lo;
    lo.reuse_address = true;
    lo.unix_domain_socket_permissions = unix_domain_socket_permissions;
    if (is_shard_aware) {
        lo.lba = server_socket::load_balancing_algorithm::port;
    }
    server_socket ss;
    try {
        ss = builder
            ? seastar::tls::listen(_credentials, addr, lo)
            : seastar::listen(addr, lo);
    } catch (...) {
        throw std::runtime_error(format("{} error while listening on {} -> {}", _server_name, addr, std::current_exception()));
    }
    _listeners.emplace_back(std::move(ss));
    _listeners_stopped = when_all(std::move(_listeners_stopped), do_accepts(_listeners.size() - 1, keepalive, addr)).discard_result();
}

// Accept loop for listener `which`.  Applies CPU-concurrency admission
// control: a connection attempt that cannot get a semaphore unit within a
// minute is accepted and then immediately shed.
future<> server::do_accepts(int which, bool keepalive, socket_address server_addr) {
    while (!_gate.is_closed()) {
        seastar::gate::holder holder(_gate);
        bool shed = false;
        try {
            semaphore_units<named_semaphore_exception_factory> units(_conns_cpu_concurrency_semaphore, 0);
            if (_conns_cpu_concurrency != std::numeric_limits<uint32_t>::max()) {
                auto u = try_get_units(_conns_cpu_concurrency_semaphore, 1);
                if (u) {
                    units = std::move(*u);
                } else {
                    _blocked_connections++;
                    try {
                        units = co_await get_units(_conns_cpu_concurrency_semaphore, 1, std::chrono::minutes(1));
                    } catch (const semaphore_timed_out&) {
                        // Accept anyway, but drop the connection right after
                        // (see `shed` handling below).
                        shed = true;
                    }
                }
            }
            accept_result cs_sa = co_await _listeners[which].accept();
            if (_gate.is_closed()) {
                break;
            }
            auto fd = std::move(cs_sa.connection);
            auto addr = std::move(cs_sa.remote_address);
            fd.set_nodelay(true);
            fd.set_keepalive(keepalive);
            auto conn = make_connection(server_addr, std::move(fd), std::move(addr), _conns_cpu_concurrency_semaphore, std::move(units));
            if (shed) {
                _shed_connections++;
                static thread_local logger::rate_limit rate_limit{std::chrono::seconds(10)};
                _logger.log(log_level::warn, rate_limit, "too many in-flight connection attempts: {}, connection dropped", _conns_cpu_concurrency_semaphore.waiters());
                conn->shutdown().ignore_ready_future();
            }
            // Move the processing into the background.
            (void)futurize_invoke([this, conn] {
                // Notify any listeners about new connection.
                return advertise_new_connection(conn);
            }).then_wrapped([this, conn] (future<> f) {
                try {
                    f.get();
                } catch (...) {
                    _logger.info("exception while advertising new connection: {}", std::current_exception());
                }
                // Block while monitoring for lifetime/errors.
                return conn->process().then_wrapped([this, conn] (auto f) {
                    try {
                        f.get();
                    } catch (...) {
                        auto ep = std::current_exception();
                        if (!is_broken_pipe_or_connection_reset(ep)) {
                            // some exceptions are expected if another side closes a connection
                            // or we're shutting down
                            _logger.info("exception while processing connection: {}", ep);
                        }
                    }
                    return unadvertise_connection(conn);
                });
            });
        } catch (...) {
            _logger.debug("accept failed: {}", std::current_exception());
        }
    }
}

// Default no-op hooks; subclasses override to track connection lifecycle.
future<> server::advertise_new_connection(shared_ptr<generic_server::connection> raw_conn) {
    return make_ready_future<>();
}

future<> server::unadvertise_connection(shared_ptr<generic_server::connection> raw_conn) {
    return make_ready_future<>();
}

}