diff --git a/alternator/auth.cc b/alternator/auth.cc index 75f6e05afb..623d829d94 100644 --- a/alternator/auth.cc +++ b/alternator/auth.cc @@ -62,6 +62,14 @@ static std::string apply_sha256(std::string_view msg) { return to_hex(hasher.finalize()); } +static std::string apply_sha256(const std::vector>& msg) { + sha256_hasher hasher; + for (const temporary_buffer& buf : msg) { + hasher.update(buf.get(), buf.size()); + } + return to_hex(hasher.finalize()); +} + static std::string format_time_point(db_clock::time_point tp) { time_t time_point_repr = db_clock::to_time_t(tp); std::string time_point_str; @@ -91,7 +99,7 @@ void check_expiry(std::string_view signature_date) { std::string get_signature(std::string_view access_key_id, std::string_view secret_access_key, std::string_view host, std::string_view method, std::string_view orig_datestamp, std::string_view signed_headers_str, const std::map& signed_headers_map, - std::string_view body_content, std::string_view region, std::string_view service, std::string_view query_string) { + const std::vector>& body_content, std::string_view region, std::string_view service, std::string_view query_string) { auto amz_date_it = signed_headers_map.find("x-amz-date"); if (amz_date_it == signed_headers_map.end()) { throw api_error::invalid_signature("X-Amz-Date header is mandatory for signature verification"); diff --git a/alternator/auth.hh b/alternator/auth.hh index d3fa02c403..25fd506368 100644 --- a/alternator/auth.hh +++ b/alternator/auth.hh @@ -39,7 +39,7 @@ using key_cache = utils::loading_cache; std::string get_signature(std::string_view access_key_id, std::string_view secret_access_key, std::string_view host, std::string_view method, std::string_view orig_datestamp, std::string_view signed_headers_str, const std::map& signed_headers_map, - std::string_view body_content, std::string_view region, std::string_view service, std::string_view query_string); + const std::vector>& body_content, std::string_view region, std::string_view service, std::string_view query_string); future get_key_from_roles(cql3::query_processor& qp, std::string username); diff --git a/alternator/executor.cc b/alternator/executor.cc index fe6f0bf596..236f1b77c1 100644 --- a/alternator/executor.cc +++ b/alternator/executor.cc @@ -3909,22 +3909,6 @@ future<> executor::create_keyspace(std::string_view keyspace_name) { }); } -static tracing::trace_state_ptr create_tracing_session() { - tracing::trace_state_props_set props; - props.set(); - return tracing::tracing::get_local_tracing_instance().create_session(tracing::trace_type::QUERY, props); -} - -tracing::trace_state_ptr executor::maybe_trace_query(client_state& client_state, sstring_view op, sstring_view query) { - tracing::trace_state_ptr trace_state; - if (tracing::tracing::get_local_tracing_instance().trace_next_query()) { - trace_state = create_tracing_session(); - tracing::add_query(trace_state, query); - tracing::begin(trace_state, format("Alternator {}", op), client_state.get_client_address()); - } - return trace_state; -} - future<> executor::start() { // Currently, nothing to do on initialization. We delay the keyspace // creation (create_keyspace()) until a table is actually created. diff --git a/alternator/executor.hh b/alternator/executor.hh index 6a89ded936..bbe84910e6 100644 --- a/alternator/executor.hh +++ b/alternator/executor.hh @@ -187,8 +187,6 @@ public: future<> create_keyspace(std::string_view keyspace_name); - static tracing::trace_state_ptr maybe_trace_query(client_state& client_state, sstring_view op, sstring_view query); - static sstring table_name(const schema&); static db::timeout_clock::time_point default_timeout(); static void set_default_timeout(db::timeout_clock::duration timeout); diff --git a/alternator/server.cc b/alternator/server.cc index 30b65f0cbe..c3146da4da 100644 --- a/alternator/server.cc +++ b/alternator/server.cc @@ -22,6 +22,8 @@ #include "alternator/server.hh" #include "log.hh" #include +#include +#include #include #include "seastarx.hh" #include "error.hh" @@ -230,7 +232,7 @@ protected: } }; -future<> server::verify_signature(const request& req) { +future<> server::verify_signature(const request& req, const chunked_content& content) { if (!_enforce_authorization) { slogger.debug("Skipping authorization"); return make_ready_future<>(); @@ -298,7 +300,7 @@ future<> server::verify_signature(const request& req) { auto cache_getter = [&qp = _qp] (std::string username) { return get_key_from_roles(qp, std::move(username)); }; - return _key_cache.get_ptr(user, cache_getter).then([this, &req, + return _key_cache.get_ptr(user, cache_getter).then([this, &req, &content, user = std::move(user), host = std::move(host), datestamp = std::move(datestamp), @@ -308,7 +310,7 @@ future<> server::verify_signature(const request& req) { service = std::move(service), user_signature = std::move(user_signature)] (key_cache::value_ptr key_ptr) { std::string signature = get_signature(user, *key_ptr, std::string_view(host), req._method, - datestamp, signed_headers_str, signed_headers_map, req.content, region, service, ""); + datestamp, signed_headers_str, signed_headers_map, content, region, service, ""); if (signature != std::string_view(user_signature)) { _key_cache.remove(user); @@ -317,43 +319,96 @@ future<> server::verify_signature(const request& req) { }); } +future server::read_content_and_verify_signature(request& req) { + assert(req.content_stream); + chunked_content content = co_await httpd::read_entire_stream(*req.content_stream); + co_await verify_signature(req, content); + co_return std::move(content); +} + +static tracing::trace_state_ptr create_tracing_session() { + tracing::trace_state_props_set props; + props.set(); + return tracing::tracing::get_local_tracing_instance().create_session(tracing::trace_type::QUERY, props); +} + +// truncated_content_view() prints a potentially long chunked_content for +// debugging purposes. In the common case when the content is not excessively +// long, it just returns a view into the given content, without any copying. +// But when the content is very long, it is truncated after some arbitrary +// max_len (or one chunk, whichever comes first), with "" added at +// the end. To do this modification to the string, we need to create a new +// std::string, so the caller must pass us a reference to one, "buf", where +// we can store the content. The returned view is only alive for as long this +// buf is kept alive. +static std::string_view truncated_content_view(const chunked_content& content, std::string& buf) { + constexpr size_t max_len = 1024; + if (content.empty()) { + return std::string_view(); + } else if (content.size() == 1 && content.begin()->size() <= max_len) { + return std::string_view(content.begin()->get(), content.begin()->size()); + } else { + buf = std::string(content.begin()->get(), std::min(content.begin()->size(), max_len)) + ""; + return std::string_view(buf); + } +} + +static tracing::trace_state_ptr maybe_trace_query(service::client_state& client_state, sstring_view op, const chunked_content& query) { + tracing::trace_state_ptr trace_state; + if (tracing::tracing::get_local_tracing_instance().trace_next_query()) { + trace_state = create_tracing_session(); + std::string buf; + tracing::add_query(trace_state, truncated_content_view(query, buf)); + tracing::begin(trace_state, format("Alternator {}", op), client_state.get_client_address()); + } + return trace_state; +} + future server::handle_api_request(std::unique_ptr&& req) { _executor._stats.total_operations++; sstring target = req->get_header(TARGET); std::vector split_target = split(target, '.'); //NOTICE(sarna): Target consists of Dynamo API version followed by a dot '.' and operation type (e.g. CreateTable) std::string op = split_target.empty() ? std::string() : std::string(split_target.back()); - slogger.trace("Request: {} {} {}", op, req->content, req->_headers); - return verify_signature(*req).then([this, op, req = std::move(req)] () mutable { - auto callback_it = _callbacks.find(op); - if (callback_it == _callbacks.end()) { - _executor._stats.unsupported_operations++; - return make_ready_future(api_error::unknown_operation(format("Unsupported operation {}", op))); - } - if (_pending_requests.get_count() >= _max_concurrent_requests) { - _executor._stats.requests_shed++; - return make_ready_future( + // JSON parsing can allocate up to roughly 2x the size of the raw + // document, + a couple of bytes for maintenance. + // TODO: consider the case where req->content_length is missing. Maybe + // we need to take the content_length_limit and return some of the units + // when we finish read_content_and_verify_signature? + size_t mem_estimate = req->content_length * 2 + 8000; + auto units_fut = get_units(*_memory_limiter, mem_estimate); + if (_memory_limiter->waiters()) { + ++_executor._stats.requests_blocked_memory; + } + return units_fut.then([this, req = std::move(req), op = std::move(op)] (semaphore_units<> units) mutable { + return read_content_and_verify_signature(*req).then([this, op = std::move(op), req = std::move(req), units = std::move(units)] (chunked_content content) mutable { + if (slogger.is_enabled(log_level::trace)) { + std::string buf; + slogger.trace("Request: {} {} {}", op, truncated_content_view(content, buf), req->_headers); + } + auto callback_it = _callbacks.find(op); + if (callback_it == _callbacks.end()) { + _executor._stats.unsupported_operations++; + return make_ready_future( + api_error::unknown_operation(format("Unsupported operation {}", op))); + } + if (_pending_requests.get_count() >= _max_concurrent_requests) { + _executor._stats.requests_shed++; + return make_ready_future( api_error::request_limit_exceeded(format("too many in-flight requests (configured via max_concurrent_requests_per_shard): {}", _pending_requests.get_count()))); - } - return with_gate(_pending_requests, [this, callback_it = std::move(callback_it), op = std::move(op), req = std::move(req)] () mutable { - //FIXME: Client state can provide more context, e.g. client's endpoint address - // We use unique_ptr because client_state cannot be moved or copied - return do_with(std::make_unique(executor::client_state::internal_tag()), - [this, callback_it = std::move(callback_it), op = std::move(op), req = std::move(req)] (std::unique_ptr& client_state) mutable { - tracing::trace_state_ptr trace_state = executor::maybe_trace_query(*client_state, op, req->content); - tracing::trace(trace_state, op); - // JSON parsing can allocate up to roughly 2x the size of the raw document, + a couple of bytes for maintenance. - // FIXME: by this time, the whole HTTP request was already read, so some memory is already occupied. - // Once HTTP allows working on streams, we should grab the permit *before* reading the HTTP payload. - size_t mem_estimate = req->content.size() * 3 + 8000; - auto units_fut = get_units(*_memory_limiter, mem_estimate); - if (_memory_limiter->waiters()) { - ++_executor._stats.requests_blocked_memory; - } - return units_fut.then([this, callback_it = std::move(callback_it), &client_state, trace_state, req = std::move(req)] (semaphore_units<> units) mutable { - return _json_parser.parse(req->content).then([this, callback_it = std::move(callback_it), &client_state, trace_state, + } + return with_gate(_pending_requests, [this, callback_it = std::move(callback_it), op = std::move(op), content = std::move(content), req = std::move(req), units=std::move(units)] () mutable { + //FIXME: Client state can provide more context, e.g. client's endpoint address + // We use unique_ptr because client_state cannot be moved or copied + return do_with(std::make_unique(executor::client_state::internal_tag()), + [this, callback_it = std::move(callback_it), op = std::move(op), content = std::move(content), req = std::move(req), units = std::move(units)] (std::unique_ptr& client_state) mutable { + tracing::trace_state_ptr trace_state = maybe_trace_query(*client_state, op, content); + tracing::trace(trace_state, op); + return _json_parser.parse(std::move(content)).then([this, callback_it = std::move(callback_it), &client_state, trace_state, units = std::move(units), req = std::move(req)] (rjson::value json_request) mutable { - return callback_it->second(_executor, *client_state, trace_state, make_service_permit(std::move(units)), std::move(json_request), std::move(req)).finally([trace_state] {}); + return callback_it->second(_executor, *client_state, trace_state, + make_service_permit(std::move(units)), std::move(json_request), + std::move(req)).finally([trace_state] {}); }); }); }); @@ -478,12 +533,14 @@ future<> server::init(net::inet_address addr, std::optional port, std: if (port) { set_routes(_http_server._routes); _http_server.set_content_length_limit(server::content_length_limit); + _http_server.set_content_streaming(true); _http_server.listen(socket_address{addr, *port}).get(); _enabled_servers.push_back(std::ref(_http_server)); } if (https_port) { set_routes(_https_server._routes); _https_server.set_content_length_limit(server::content_length_limit); + _https_server.set_content_streaming(true); _https_server.set_tls_credentials(creds->build_reloadable_server_credentials([](const std::unordered_set& files, std::exception_ptr ep) { if (ep) { slogger.warn("Exception loading {}: {}", files, ep); @@ -521,7 +578,7 @@ server::json_parser::json_parser() : _run_parse_json_thread(async([this] { return; } try { - _parsed_document = rjson::parse_yieldable(_raw_document); + _parsed_document = rjson::parse_yieldable(std::move(_raw_document)); _current_exception = nullptr; } catch (...) { _current_exception = std::current_exception(); @@ -531,12 +588,12 @@ server::json_parser::json_parser() : _run_parse_json_thread(async([this] { })) { } -future server::json_parser::parse(std::string_view content) { +future server::json_parser::parse(chunked_content&& content) { if (content.size() < yieldable_parsing_threshold) { - return make_ready_future(rjson::parse(content)); + return make_ready_future(rjson::parse(std::move(content))); } - return with_semaphore(_parsing_sem, 1, [this, content] { - _raw_document = content; + return with_semaphore(_parsing_sem, 1, [this, content = std::move(content)] () mutable { + _raw_document = std::move(content); _document_waiting.signal(); return _document_parsed.wait().then([this] { if (_current_exception) { diff --git a/alternator/server.hh b/alternator/server.hh index 3c43ca4997..17d039b970 100644 --- a/alternator/server.hh +++ b/alternator/server.hh @@ -33,6 +33,8 @@ namespace alternator { +using chunked_content = rjson::chunked_content; + class server { static constexpr size_t content_length_limit = 16*MB; using alternator_callback = std::function(executor&, executor::client_state&, @@ -55,7 +57,7 @@ class server { class json_parser { static constexpr size_t yieldable_parsing_threshold = 16*KB; - std::string_view _raw_document; + chunked_content _raw_document; rjson::value _parsed_document; std::exception_ptr _current_exception; semaphore _parsing_sem{1}; @@ -65,7 +67,10 @@ class server { future<> _run_parse_json_thread; public: json_parser(); - future parse(std::string_view content); + // Moving a chunked_content into parse() allows parse() to free each + // chunk as soon as it is parsed, so when chunks are relatively small, + // we don't need to store the sum of unparsed and parsed sizes. + future parse(chunked_content&& content); future<> stop(); }; json_parser _json_parser; @@ -78,7 +83,8 @@ public: future<> stop(); private: void set_routes(seastar::httpd::routes& r); - future<> verify_signature(const seastar::httpd::request& r); + future<> verify_signature(const seastar::httpd::request&, const chunked_content&); + future read_content_and_verify_signature(seastar::httpd::request&); future handle_api_request(std::unique_ptr&& req); }; diff --git a/test/alternator/test_batch.py b/test/alternator/test_batch.py index 5e2d586080..61e275d780 100644 --- a/test/alternator/test_batch.py +++ b/test/alternator/test_batch.py @@ -318,3 +318,19 @@ def test_batch_unprocessed(test_table_s): test_table_s.name: {'Keys': [{'p': p}], 'ProjectionExpression': 'p, a', 'ConsistentRead': True} }) assert 'UnprocessedKeys' in read_reply and read_reply['UnprocessedKeys'] == dict() + +# According to the DynamoDB document, a single BatchWriteItem operation is +# limited to 25 update requests, up to 400 KB each, or 16 MB total (25*400 +# is only 10 MB, but the JSON format has additional overheads). If we write +# less than those limits in a single BatchWriteItem operation, it should +# work. Testing a large request exercises our code which calculates the +# request signature, and parses a long request (issue #7213). +def test_batch_write_item_large(test_table_sn): + p = random_string() + long_content = random_string(100)*500 + write_reply = test_table_sn.meta.client.batch_write_item(RequestItems = { + test_table_sn.name: [{'PutRequest': {'Item': {'p': p, 'c': i, 'content': long_content}}} for i in range(25)], + }) + assert 'UnprocessedItems' in write_reply and write_reply['UnprocessedItems'] == dict() + assert full_query(test_table_sn, KeyConditionExpression='p=:p', ExpressionAttributeValues={':p': p} + ) == [{'p': p, 'c': i, 'content': long_content} for i in range(25)] diff --git a/utils/rjson.cc b/utils/rjson.cc index a5d4db1b18..0e677239f2 100644 --- a/utils/rjson.cc +++ b/utils/rjson.cc @@ -27,6 +27,66 @@ namespace rjson { allocator the_allocator; +// chunked_content_stream is a wrapper of a chunked_content which +// presents the Stream concept that the rapidjson library expects as input +// for its parser (https://rapidjson.org/classrapidjson_1_1_stream.html). +// This wrapper owns the chunked_content, so it can free each chunk as +// soon as it's parsed. +class chunked_content_stream { +private: + chunked_content _content; + chunked_content::iterator _current_chunk; + // _count only needed for Tell(). 32 bits is enough, we don't allow + // more than 16 MB requests anyway. + unsigned _count; +public: + typedef char Ch; + chunked_content_stream(chunked_content&& content) + : _content(std::move(content)) + , _current_chunk(_content.begin()) + {} + bool eof() const { + return _current_chunk == _content.end(); + } + // Methods needed by rapidjson's Stream concept (see + // https://rapidjson.org/classrapidjson_1_1_stream.html): + char Peek() const { + if (eof()) { + // Rapidjson's Stream concept does not have the explicit notion of + // an "end of file". Instead, reading after the end of stream will + // return a null byte. This makes these streams appear like null- + // terminated C strings. It is good enough for reading JSON, which + // anyway can't include bare null characters. + return '\0'; + } else { + return *_current_chunk->begin(); + } + } + char Take() { + if (eof()) { + return '\0'; + } else { + char ret = *_current_chunk->begin(); + _current_chunk->trim_front(1); + ++_count; + if (_current_chunk->empty()) { + *_current_chunk = temporary_buffer(); + ++_current_chunk; + } + return ret; + } + } + size_t Tell() const { + return _count; + } + // Not used in input streams, but unfortunately we still need to implement + Ch* PutBegin() { RAPIDJSON_ASSERT(false); return 0; } + void Put(Ch) { RAPIDJSON_ASSERT(false); } + void Flush() { RAPIDJSON_ASSERT(false); } + size_t PutEnd(Ch*) { RAPIDJSON_ASSERT(false); return 0; } + +}; + /* * This wrapper class adds nested level checks to rapidjson's handlers. * Each rapidjson handler implements functions for accepting JSON values, @@ -49,11 +109,11 @@ public: guarded_yieldable_json_handler(string_buffer& buf, size_t max_nested_level) : handler_base(buf), _max_nested_level(max_nested_level) {} - void Parse(const char* str, size_t length) { - rapidjson::MemoryStream ms(static_cast(str), length * sizeof(typename encoding::Ch)); - rapidjson::EncodedInputStream is(ms); + // Parse any stream fitting https://rapidjson.org/classrapidjson_1_1_stream.html + template + void Parse(Stream& stream) { rapidjson::GenericReader reader(&the_allocator); - reader.Parse(is, *this); + reader.Parse(stream, *this); if (reader.HasParseError()) { throw rjson::error(format("Parsing JSON failed: {}", rapidjson::GetParseError_En(reader.GetParseErrorCode()))); } @@ -70,6 +130,18 @@ public: auto dummy_generator = [](handler_base&){return true;}; handler_base::Populate(dummy_generator); } + void Parse(const char* str, size_t length) { + rapidjson::MemoryStream ms(static_cast(str), length * sizeof(typename encoding::Ch)); + rapidjson::EncodedInputStream is(ms); + Parse(is); + } + + void Parse(chunked_content&& content) { + // Note that content was moved into this function. The intention is + // that we free every chunk we are done with. + chunked_content_stream is(std::move(content)); + Parse(is); + } bool StartObject() { ++_nested_level; @@ -156,6 +228,16 @@ rjson::value parse(std::string_view str) { return std::move(v); } +rjson::value parse(chunked_content&& content) { + guarded_yieldable_json_handler d(78); + d.Parse(std::move(content)); + if (d.HasParseError()) { + throw rjson::error(format("Parsing JSON failed: {}", GetParseError_En(d.GetParseError()))); + } + rjson::value& v = d; + return std::move(v); +} + std::optional try_parse(std::string_view str) { guarded_yieldable_json_handler d(78); try { @@ -180,6 +262,16 @@ rjson::value parse_yieldable(std::string_view str) { return std::move(v); } +rjson::value parse_yieldable(chunked_content&& content) { + guarded_yieldable_json_handler d(78); + d.Parse(std::move(content)); + if (d.HasParseError()) { + throw rjson::error(format("Parsing JSON failed: {}", GetParseError_En(d.GetParseError()))); + } + rjson::value& v = d; + return std::move(v); +} + rjson::value& get(rjson::value& value, std::string_view name) { // Although FindMember() has a variant taking a StringRef, it ignores the // given length (see https://github.com/Tencent/rapidjson/issues/1649). diff --git a/utils/rjson.hh b/utils/rjson.hh index 5803f6fff9..912335b48e 100644 --- a/utils/rjson.hh +++ b/utils/rjson.hh @@ -141,6 +141,18 @@ std::optional try_parse(std::string_view str); // Needs to be run in thread context rjson::value parse_yieldable(std::string_view str); +// chunked_content holds a non-contiguous buffer of bytes - such as bytes +// read by httpd::read_entire_stream(). We assume that chunked_content does +// not contain any empty buffers (the vector can be empty, meaning empty +// content - but individual buffers cannot). +using chunked_content = std::vector>; + +// Additional variants of parse() and parse_yieldable() that work on non- +// contiguous chunked_content. The chunked_content is moved into the parsing +// function so that we can start freeing chunks as soon as we parse them. +rjson::value parse(chunked_content&&); +rjson::value parse_yieldable(chunked_content&&); + // Creates a JSON value (of JSON string type) out of internal string representations. // The string value is copied, so str's liveness does not need to be persisted. rjson::value from_string(const char* str, size_t size);