From f41dac2a3ae0b631b7381072dfc7d7fd03b3f1fb Mon Sep 17 00:00:00 2001 From: Nadav Har'El Date: Wed, 10 Mar 2021 00:25:25 +0200 Subject: [PATCH] alternator: avoid large contiguous allocation for request body Alternator request sizes can be up to 16 MB, but the current implementation had the Seastar HTTP server read the entire request as a contiguous string, and then processed it. We can't avoid reading the entire request up-front - we want to verify its integrity before doing any additional processing on it. But there is no reason why the entire request needs to be stored in one big *contiguous* allocation. This is always a bad idea. We should use a non- contiguous buffer, and that's the goal of this patch. We use a new Seastar HTTPD feature where we can ask for an input stream, instead of a string, for the request's body. We then begin the request handling by reading the content of this stream into a vector<temporary_buffer<char>> (which we alias "chunked_content"). We then use this non-contiguous buffer to verify the request's signature and if successful - parse the request JSON and finally execute it. Beyond avoiding contiguous allocations, another benefit of this patch is that while parsing a long request composed of chunks, we free each chunk as soon as its parsing is completed. This reduces the peak amount of memory used by the query - we no longer need to store both unparsed and parsed versions of the request at the same time. Although we already had tests with requests of different lengths, most of them were short enough to only have one chunk, and only a few had 2 or 3 chunks. So we also add a test which makes a much longer request (a BatchWriteItem with large items), which in my experiment had 17 chunks. The goal of this test is to verify that the new signature and JSON parsing code which needs to cross chunk boundaries works as expected. Fixes #7213. 
Signed-off-by: Nadav Har'El Message-Id: <20210309222525.1628234-1-nyh@scylladb.com> --- alternator/auth.cc | 10 ++- alternator/auth.hh | 2 +- alternator/executor.cc | 16 ----- alternator/executor.hh | 2 - alternator/server.cc | 131 ++++++++++++++++++++++++---------- alternator/server.hh | 12 +++- test/alternator/test_batch.py | 16 +++++ utils/rjson.cc | 100 ++++++++++++++++++++++++-- utils/rjson.hh | 12 ++++ 9 files changed, 237 insertions(+), 64 deletions(-) diff --git a/alternator/auth.cc b/alternator/auth.cc index 75f6e05afb..623d829d94 100644 --- a/alternator/auth.cc +++ b/alternator/auth.cc @@ -62,6 +62,14 @@ static std::string apply_sha256(std::string_view msg) { return to_hex(hasher.finalize()); } +static std::string apply_sha256(const std::vector>& msg) { + sha256_hasher hasher; + for (const temporary_buffer& buf : msg) { + hasher.update(buf.get(), buf.size()); + } + return to_hex(hasher.finalize()); +} + static std::string format_time_point(db_clock::time_point tp) { time_t time_point_repr = db_clock::to_time_t(tp); std::string time_point_str; @@ -91,7 +99,7 @@ void check_expiry(std::string_view signature_date) { std::string get_signature(std::string_view access_key_id, std::string_view secret_access_key, std::string_view host, std::string_view method, std::string_view orig_datestamp, std::string_view signed_headers_str, const std::map& signed_headers_map, - std::string_view body_content, std::string_view region, std::string_view service, std::string_view query_string) { + const std::vector>& body_content, std::string_view region, std::string_view service, std::string_view query_string) { auto amz_date_it = signed_headers_map.find("x-amz-date"); if (amz_date_it == signed_headers_map.end()) { throw api_error::invalid_signature("X-Amz-Date header is mandatory for signature verification"); diff --git a/alternator/auth.hh b/alternator/auth.hh index d3fa02c403..25fd506368 100644 --- a/alternator/auth.hh +++ b/alternator/auth.hh @@ -39,7 +39,7 @@ using key_cache 
= utils::loading_cache; std::string get_signature(std::string_view access_key_id, std::string_view secret_access_key, std::string_view host, std::string_view method, std::string_view orig_datestamp, std::string_view signed_headers_str, const std::map& signed_headers_map, - std::string_view body_content, std::string_view region, std::string_view service, std::string_view query_string); + const std::vector>& body_content, std::string_view region, std::string_view service, std::string_view query_string); future get_key_from_roles(cql3::query_processor& qp, std::string username); diff --git a/alternator/executor.cc b/alternator/executor.cc index fe6f0bf596..236f1b77c1 100644 --- a/alternator/executor.cc +++ b/alternator/executor.cc @@ -3909,22 +3909,6 @@ future<> executor::create_keyspace(std::string_view keyspace_name) { }); } -static tracing::trace_state_ptr create_tracing_session() { - tracing::trace_state_props_set props; - props.set(); - return tracing::tracing::get_local_tracing_instance().create_session(tracing::trace_type::QUERY, props); -} - -tracing::trace_state_ptr executor::maybe_trace_query(client_state& client_state, sstring_view op, sstring_view query) { - tracing::trace_state_ptr trace_state; - if (tracing::tracing::get_local_tracing_instance().trace_next_query()) { - trace_state = create_tracing_session(); - tracing::add_query(trace_state, query); - tracing::begin(trace_state, format("Alternator {}", op), client_state.get_client_address()); - } - return trace_state; -} - future<> executor::start() { // Currently, nothing to do on initialization. We delay the keyspace // creation (create_keyspace()) until a table is actually created. 
diff --git a/alternator/executor.hh b/alternator/executor.hh index 6a89ded936..bbe84910e6 100644 --- a/alternator/executor.hh +++ b/alternator/executor.hh @@ -187,8 +187,6 @@ public: future<> create_keyspace(std::string_view keyspace_name); - static tracing::trace_state_ptr maybe_trace_query(client_state& client_state, sstring_view op, sstring_view query); - static sstring table_name(const schema&); static db::timeout_clock::time_point default_timeout(); static void set_default_timeout(db::timeout_clock::duration timeout); diff --git a/alternator/server.cc b/alternator/server.cc index 30b65f0cbe..c3146da4da 100644 --- a/alternator/server.cc +++ b/alternator/server.cc @@ -22,6 +22,8 @@ #include "alternator/server.hh" #include "log.hh" #include +#include +#include #include #include "seastarx.hh" #include "error.hh" @@ -230,7 +232,7 @@ protected: } }; -future<> server::verify_signature(const request& req) { +future<> server::verify_signature(const request& req, const chunked_content& content) { if (!_enforce_authorization) { slogger.debug("Skipping authorization"); return make_ready_future<>(); @@ -298,7 +300,7 @@ future<> server::verify_signature(const request& req) { auto cache_getter = [&qp = _qp] (std::string username) { return get_key_from_roles(qp, std::move(username)); }; - return _key_cache.get_ptr(user, cache_getter).then([this, &req, + return _key_cache.get_ptr(user, cache_getter).then([this, &req, &content, user = std::move(user), host = std::move(host), datestamp = std::move(datestamp), @@ -308,7 +310,7 @@ future<> server::verify_signature(const request& req) { service = std::move(service), user_signature = std::move(user_signature)] (key_cache::value_ptr key_ptr) { std::string signature = get_signature(user, *key_ptr, std::string_view(host), req._method, - datestamp, signed_headers_str, signed_headers_map, req.content, region, service, ""); + datestamp, signed_headers_str, signed_headers_map, content, region, service, ""); if (signature != 
std::string_view(user_signature)) { _key_cache.remove(user); @@ -317,43 +319,96 @@ future<> server::verify_signature(const request& req) { }); } +future server::read_content_and_verify_signature(request& req) { + assert(req.content_stream); + chunked_content content = co_await httpd::read_entire_stream(*req.content_stream); + co_await verify_signature(req, content); + co_return std::move(content); +} + +static tracing::trace_state_ptr create_tracing_session() { + tracing::trace_state_props_set props; + props.set(); + return tracing::tracing::get_local_tracing_instance().create_session(tracing::trace_type::QUERY, props); +} + +// truncated_content_view() prints a potentially long chunked_content for +// debugging purposes. In the common case when the content is not excessively +// long, it just returns a view into the given content, without any copying. +// But when the content is very long, it is truncated after some arbitrary +// max_len (or one chunk, whichever comes first), with "" added at +// the end. To do this modification to the string, we need to create a new +// std::string, so the caller must pass us a reference to one, "buf", where +// we can store the content. The returned view is only alive for as long this +// buf is kept alive. 
+static std::string_view truncated_content_view(const chunked_content& content, std::string& buf) { + constexpr size_t max_len = 1024; + if (content.empty()) { + return std::string_view(); + } else if (content.size() == 1 && content.begin()->size() <= max_len) { + return std::string_view(content.begin()->get(), content.begin()->size()); + } else { + buf = std::string(content.begin()->get(), std::min(content.begin()->size(), max_len)) + ""; + return std::string_view(buf); + } +} + +static tracing::trace_state_ptr maybe_trace_query(service::client_state& client_state, sstring_view op, const chunked_content& query) { + tracing::trace_state_ptr trace_state; + if (tracing::tracing::get_local_tracing_instance().trace_next_query()) { + trace_state = create_tracing_session(); + std::string buf; + tracing::add_query(trace_state, truncated_content_view(query, buf)); + tracing::begin(trace_state, format("Alternator {}", op), client_state.get_client_address()); + } + return trace_state; +} + future server::handle_api_request(std::unique_ptr&& req) { _executor._stats.total_operations++; sstring target = req->get_header(TARGET); std::vector split_target = split(target, '.'); //NOTICE(sarna): Target consists of Dynamo API version followed by a dot '.' and operation type (e.g. CreateTable) std::string op = split_target.empty() ? 
std::string() : std::string(split_target.back()); - slogger.trace("Request: {} {} {}", op, req->content, req->_headers); - return verify_signature(*req).then([this, op, req = std::move(req)] () mutable { - auto callback_it = _callbacks.find(op); - if (callback_it == _callbacks.end()) { - _executor._stats.unsupported_operations++; - return make_ready_future(api_error::unknown_operation(format("Unsupported operation {}", op))); - } - if (_pending_requests.get_count() >= _max_concurrent_requests) { - _executor._stats.requests_shed++; - return make_ready_future( + // JSON parsing can allocate up to roughly 2x the size of the raw + // document, + a couple of bytes for maintenance. + // TODO: consider the case where req->content_length is missing. Maybe + // we need to take the content_length_limit and return some of the units + // when we finish read_content_and_verify_signature? + size_t mem_estimate = req->content_length * 2 + 8000; + auto units_fut = get_units(*_memory_limiter, mem_estimate); + if (_memory_limiter->waiters()) { + ++_executor._stats.requests_blocked_memory; + } + return units_fut.then([this, req = std::move(req), op = std::move(op)] (semaphore_units<> units) mutable { + return read_content_and_verify_signature(*req).then([this, op = std::move(op), req = std::move(req), units = std::move(units)] (chunked_content content) mutable { + if (slogger.is_enabled(log_level::trace)) { + std::string buf; + slogger.trace("Request: {} {} {}", op, truncated_content_view(content, buf), req->_headers); + } + auto callback_it = _callbacks.find(op); + if (callback_it == _callbacks.end()) { + _executor._stats.unsupported_operations++; + return make_ready_future( + api_error::unknown_operation(format("Unsupported operation {}", op))); + } + if (_pending_requests.get_count() >= _max_concurrent_requests) { + _executor._stats.requests_shed++; + return make_ready_future( api_error::request_limit_exceeded(format("too many in-flight requests (configured via 
max_concurrent_requests_per_shard): {}", _pending_requests.get_count()))); - } - return with_gate(_pending_requests, [this, callback_it = std::move(callback_it), op = std::move(op), req = std::move(req)] () mutable { - //FIXME: Client state can provide more context, e.g. client's endpoint address - // We use unique_ptr because client_state cannot be moved or copied - return do_with(std::make_unique(executor::client_state::internal_tag()), - [this, callback_it = std::move(callback_it), op = std::move(op), req = std::move(req)] (std::unique_ptr& client_state) mutable { - tracing::trace_state_ptr trace_state = executor::maybe_trace_query(*client_state, op, req->content); - tracing::trace(trace_state, op); - // JSON parsing can allocate up to roughly 2x the size of the raw document, + a couple of bytes for maintenance. - // FIXME: by this time, the whole HTTP request was already read, so some memory is already occupied. - // Once HTTP allows working on streams, we should grab the permit *before* reading the HTTP payload. - size_t mem_estimate = req->content.size() * 3 + 8000; - auto units_fut = get_units(*_memory_limiter, mem_estimate); - if (_memory_limiter->waiters()) { - ++_executor._stats.requests_blocked_memory; - } - return units_fut.then([this, callback_it = std::move(callback_it), &client_state, trace_state, req = std::move(req)] (semaphore_units<> units) mutable { - return _json_parser.parse(req->content).then([this, callback_it = std::move(callback_it), &client_state, trace_state, + } + return with_gate(_pending_requests, [this, callback_it = std::move(callback_it), op = std::move(op), content = std::move(content), req = std::move(req), units=std::move(units)] () mutable { + //FIXME: Client state can provide more context, e.g. 
client's endpoint address + // We use unique_ptr because client_state cannot be moved or copied + return do_with(std::make_unique(executor::client_state::internal_tag()), + [this, callback_it = std::move(callback_it), op = std::move(op), content = std::move(content), req = std::move(req), units = std::move(units)] (std::unique_ptr& client_state) mutable { + tracing::trace_state_ptr trace_state = maybe_trace_query(*client_state, op, content); + tracing::trace(trace_state, op); + return _json_parser.parse(std::move(content)).then([this, callback_it = std::move(callback_it), &client_state, trace_state, units = std::move(units), req = std::move(req)] (rjson::value json_request) mutable { - return callback_it->second(_executor, *client_state, trace_state, make_service_permit(std::move(units)), std::move(json_request), std::move(req)).finally([trace_state] {}); + return callback_it->second(_executor, *client_state, trace_state, + make_service_permit(std::move(units)), std::move(json_request), + std::move(req)).finally([trace_state] {}); }); }); }); @@ -478,12 +533,14 @@ future<> server::init(net::inet_address addr, std::optional port, std: if (port) { set_routes(_http_server._routes); _http_server.set_content_length_limit(server::content_length_limit); + _http_server.set_content_streaming(true); _http_server.listen(socket_address{addr, *port}).get(); _enabled_servers.push_back(std::ref(_http_server)); } if (https_port) { set_routes(_https_server._routes); _https_server.set_content_length_limit(server::content_length_limit); + _https_server.set_content_streaming(true); _https_server.set_tls_credentials(creds->build_reloadable_server_credentials([](const std::unordered_set& files, std::exception_ptr ep) { if (ep) { slogger.warn("Exception loading {}: {}", files, ep); @@ -521,7 +578,7 @@ server::json_parser::json_parser() : _run_parse_json_thread(async([this] { return; } try { - _parsed_document = rjson::parse_yieldable(_raw_document); + _parsed_document = 
rjson::parse_yieldable(std::move(_raw_document)); _current_exception = nullptr; } catch (...) { _current_exception = std::current_exception(); @@ -531,12 +588,12 @@ server::json_parser::json_parser() : _run_parse_json_thread(async([this] { })) { } -future server::json_parser::parse(std::string_view content) { +future server::json_parser::parse(chunked_content&& content) { if (content.size() < yieldable_parsing_threshold) { - return make_ready_future(rjson::parse(content)); + return make_ready_future(rjson::parse(std::move(content))); } - return with_semaphore(_parsing_sem, 1, [this, content] { - _raw_document = content; + return with_semaphore(_parsing_sem, 1, [this, content = std::move(content)] () mutable { + _raw_document = std::move(content); _document_waiting.signal(); return _document_parsed.wait().then([this] { if (_current_exception) { diff --git a/alternator/server.hh b/alternator/server.hh index 3c43ca4997..17d039b970 100644 --- a/alternator/server.hh +++ b/alternator/server.hh @@ -33,6 +33,8 @@ namespace alternator { +using chunked_content = rjson::chunked_content; + class server { static constexpr size_t content_length_limit = 16*MB; using alternator_callback = std::function(executor&, executor::client_state&, @@ -55,7 +57,7 @@ class server { class json_parser { static constexpr size_t yieldable_parsing_threshold = 16*KB; - std::string_view _raw_document; + chunked_content _raw_document; rjson::value _parsed_document; std::exception_ptr _current_exception; semaphore _parsing_sem{1}; @@ -65,7 +67,10 @@ class server { future<> _run_parse_json_thread; public: json_parser(); - future parse(std::string_view content); + // Moving a chunked_content into parse() allows parse() to free each + // chunk as soon as it is parsed, so when chunks are relatively small, + // we don't need to store the sum of unparsed and parsed sizes. 
+ future parse(chunked_content&& content); future<> stop(); }; json_parser _json_parser; @@ -78,7 +83,8 @@ public: future<> stop(); private: void set_routes(seastar::httpd::routes& r); - future<> verify_signature(const seastar::httpd::request& r); + future<> verify_signature(const seastar::httpd::request&, const chunked_content&); + future read_content_and_verify_signature(seastar::httpd::request&); future handle_api_request(std::unique_ptr&& req); }; diff --git a/test/alternator/test_batch.py b/test/alternator/test_batch.py index 5e2d586080..61e275d780 100644 --- a/test/alternator/test_batch.py +++ b/test/alternator/test_batch.py @@ -318,3 +318,19 @@ def test_batch_unprocessed(test_table_s): test_table_s.name: {'Keys': [{'p': p}], 'ProjectionExpression': 'p, a', 'ConsistentRead': True} }) assert 'UnprocessedKeys' in read_reply and read_reply['UnprocessedKeys'] == dict() + +# According to the DynamoDB document, a single BatchWriteItem operation is +# limited to 25 update requests, up to 400 KB each, or 16 MB total (25*400 +# is only 10 MB, but the JSON format has additional overheads). If we write +# less than those limits in a single BatchWriteItem operation, it should +# work. Testing a large request exercises our code which calculates the +# request signature, and parses a long request (issue #7213). 
+def test_batch_write_item_large(test_table_sn): + p = random_string() + long_content = random_string(100)*500 + write_reply = test_table_sn.meta.client.batch_write_item(RequestItems = { + test_table_sn.name: [{'PutRequest': {'Item': {'p': p, 'c': i, 'content': long_content}}} for i in range(25)], + }) + assert 'UnprocessedItems' in write_reply and write_reply['UnprocessedItems'] == dict() + assert full_query(test_table_sn, KeyConditionExpression='p=:p', ExpressionAttributeValues={':p': p} + ) == [{'p': p, 'c': i, 'content': long_content} for i in range(25)] diff --git a/utils/rjson.cc b/utils/rjson.cc index a5d4db1b18..0e677239f2 100644 --- a/utils/rjson.cc +++ b/utils/rjson.cc @@ -27,6 +27,66 @@ namespace rjson { allocator the_allocator; +// chunked_content_stream is a wrapper of a chunked_content which +// presents the Stream concept that the rapidjson library expects as input +// for its parser (https://rapidjson.org/classrapidjson_1_1_stream.html). +// This wrapper owns the chunked_content, so it can free each chunk as +// soon as it's parsed. +class chunked_content_stream { +private: + chunked_content _content; + chunked_content::iterator _current_chunk; + // _count only needed for Tell(). 32 bits is enough, we don't allow + // more than 16 MB requests anyway. + unsigned _count; +public: + typedef char Ch; + chunked_content_stream(chunked_content&& content) + : _content(std::move(content)) + , _current_chunk(_content.begin()) + {} + bool eof() const { + return _current_chunk == _content.end(); + } + // Methods needed by rapidjson's Stream concept (see + // https://rapidjson.org/classrapidjson_1_1_stream.html): + char Peek() const { + if (eof()) { + // Rapidjson's Stream concept does not have the explicit notion of + // an "end of file". Instead, reading after the end of stream will + // return a null byte. This makes these streams appear like null- + // terminated C strings. 
It is good enough for reading JSON, which + // anyway can't include bare null characters. + return '\0'; + } else { + return *_current_chunk->begin(); + } + } + char Take() { + if (eof()) { + return '\0'; + } else { + char ret = *_current_chunk->begin(); + _current_chunk->trim_front(1); + ++_count; + if (_current_chunk->empty()) { + *_current_chunk = temporary_buffer(); + ++_current_chunk; + } + return ret; + } + } + size_t Tell() const { + return _count; + } + // Not used in input streams, but unfortunately we still need to implement + Ch* PutBegin() { RAPIDJSON_ASSERT(false); return 0; } + void Put(Ch) { RAPIDJSON_ASSERT(false); } + void Flush() { RAPIDJSON_ASSERT(false); } + size_t PutEnd(Ch*) { RAPIDJSON_ASSERT(false); return 0; } + +}; + /* * This wrapper class adds nested level checks to rapidjson's handlers. * Each rapidjson handler implements functions for accepting JSON values, @@ -49,11 +109,11 @@ public: guarded_yieldable_json_handler(string_buffer& buf, size_t max_nested_level) : handler_base(buf), _max_nested_level(max_nested_level) {} - void Parse(const char* str, size_t length) { - rapidjson::MemoryStream ms(static_cast(str), length * sizeof(typename encoding::Ch)); - rapidjson::EncodedInputStream is(ms); + // Parse any stream fitting https://rapidjson.org/classrapidjson_1_1_stream.html + template + void Parse(Stream& stream) { rapidjson::GenericReader reader(&the_allocator); - reader.Parse(is, *this); + reader.Parse(stream, *this); if (reader.HasParseError()) { throw rjson::error(format("Parsing JSON failed: {}", rapidjson::GetParseError_En(reader.GetParseErrorCode()))); } @@ -70,6 +130,18 @@ public: auto dummy_generator = [](handler_base&){return true;}; handler_base::Populate(dummy_generator); } + void Parse(const char* str, size_t length) { + rapidjson::MemoryStream ms(static_cast(str), length * sizeof(typename encoding::Ch)); + rapidjson::EncodedInputStream is(ms); + Parse(is); + } + + void Parse(chunked_content&& content) { + // Note that 
content was moved into this function. The intention is + // that we free every chunk we are done with. + chunked_content_stream is(std::move(content)); + Parse(is); + } bool StartObject() { ++_nested_level; @@ -156,6 +228,16 @@ rjson::value parse(std::string_view str) { return std::move(v); } +rjson::value parse(chunked_content&& content) { + guarded_yieldable_json_handler d(78); + d.Parse(std::move(content)); + if (d.HasParseError()) { + throw rjson::error(format("Parsing JSON failed: {}", GetParseError_En(d.GetParseError()))); + } + rjson::value& v = d; + return std::move(v); +} + std::optional try_parse(std::string_view str) { guarded_yieldable_json_handler d(78); try { @@ -180,6 +262,16 @@ rjson::value parse_yieldable(std::string_view str) { return std::move(v); } +rjson::value parse_yieldable(chunked_content&& content) { + guarded_yieldable_json_handler d(78); + d.Parse(std::move(content)); + if (d.HasParseError()) { + throw rjson::error(format("Parsing JSON failed: {}", GetParseError_En(d.GetParseError()))); + } + rjson::value& v = d; + return std::move(v); +} + rjson::value& get(rjson::value& value, std::string_view name) { // Although FindMember() has a variant taking a StringRef, it ignores the // given length (see https://github.com/Tencent/rapidjson/issues/1649). diff --git a/utils/rjson.hh b/utils/rjson.hh index 5803f6fff9..912335b48e 100644 --- a/utils/rjson.hh +++ b/utils/rjson.hh @@ -141,6 +141,18 @@ std::optional try_parse(std::string_view str); // Needs to be run in thread context rjson::value parse_yieldable(std::string_view str); +// chunked_content holds a non-contiguous buffer of bytes - such as bytes +// read by httpd::read_entire_stream(). We assume that chunked_content does +// not contain any empty buffers (the vector can be empty, meaning empty +// content - but individual buffers cannot). +using chunked_content = std::vector>; + +// Additional variants of parse() and parse_yieldable() that work on non- +// contiguous chunked_content. 
The chunked_content is moved into the parsing +// function so that we can start freeing chunks as soon as we parse them. +rjson::value parse(chunked_content&&); +rjson::value parse_yieldable(chunked_content&&); + // Creates a JSON value (of JSON string type) out of internal string representations. // The string value is copied, so str's liveness does not need to be persisted. rjson::value from_string(const char* str, size_t size);