From c00814383df10189b8912966828274ac84397dbf Mon Sep 17 00:00:00 2001 From: Vlad Zolotarov Date: Tue, 16 May 2017 14:31:54 -0400 Subject: [PATCH 1/5] cql_server::response: store the frame flags inside the class It makes a lot more sense to keep the flags mask inside the response and update it each time the corresponding feature is set instead of holding the separate components like tracing state pointer. This patch adds this ability to set the flags. Signed-off-by: Vlad Zolotarov --- transport/server.cc | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/transport/server.cc b/transport/server.cc index 7b41de047e..3bd0c7929f 100644 --- a/transport/server.cc +++ b/transport/server.cc @@ -186,6 +186,7 @@ cql_load_balance parse_load_balance(sstring value) class cql_server::response { int16_t _stream; cql_binary_opcode _opcode; + uint8_t _flags = 0; // a bitwise OR mask of zero or more cql_frame_flags values std::experimental::optional _tracing_id; std::vector _body; public: @@ -194,6 +195,10 @@ public: , _opcode{opcode} { } + void set_frame_flag(cql_frame_flags flag) noexcept { + _flags |= flag; + } + void set_tracing_id(const utils::UUID& id) { _tracing_id = id; } @@ -231,20 +236,20 @@ private: std::vector compress_snappy(const std::vector& body); template - sstring make_frame_one(uint8_t version, uint8_t flags, size_t length) { + sstring make_frame_one(uint8_t version, size_t length) { size_t extra_len = 0; // If tracing was requested the response should contain a "tracing // session ID" which is a 16 bytes UUID. if (_tracing_id) { extra_len += 16; - flags |= cql_frame_flags::tracing; + set_frame_flag(cql_frame_flags::tracing); } sstring frame_buf(sstring::initialized_later(), sizeof(CqlFrameHeaderType) + extra_len); auto* frame = reinterpret_cast(frame_buf.begin()); frame->version = version | 0x80; - frame->flags = flags; + frame->flags = _flags; frame->opcode = static_cast(_opcode); frame->length = htonl(length + extra_len); frame->stream = net::hton((decltype(frame->stream))_stream); @@ -257,15 +262,15 @@ private: return frame_buf; } - sstring make_frame(uint8_t version, uint8_t flags, size_t length) { + sstring make_frame(uint8_t version, size_t length) { if (version > 0x04) { throw exceptions::protocol_exception(sprint("Invalid or unsupported protocol version: %d", version)); } if (version > 0x02) { - return make_frame_one(version, flags, length); + return make_frame_one(version, length); } else { - return make_frame_one(version, flags, length); + return make_frame_one(version, length); } } }; @@ -1496,7 +1501,7 @@ cql3::raw_value_view cql_server::connection::read_value_view(bytes_view& buf) { scattered_message cql_server::response::make_message(uint8_t version) { scattered_message msg; sstring body{_body.data(), _body.size()}; - sstring frame = make_frame(version, 0x00, body.size()); + sstring frame = make_frame(version, _body.size()); msg.append(std::move(frame)); msg.append(std::move(body)); return msg; @@ -1504,12 +1509,11 @@ scattered_message cql_server::response::make_message(uint8_t version) { future<> cql_server::response::output(output_stream& out, uint8_t version, cql_compression compression) { - uint8_t flags = 0; if (compression != cql_compression::none) { - flags |= cql_frame_flags::compression; _body = compress(_body, compression); + set_frame_flag(cql_frame_flags::compression); } - auto frame = make_frame(version, flags, _body.size()); + auto frame = make_frame(version, _body.size()); auto tmp = temporary_buffer(frame.size()); std::copy_n(frame.begin(), frame.size(), tmp.get_write()); auto f = out.write(tmp.get(), tmp.size()); From a33fe5b7755f79cece56fb7ed0a090f5f319817b Mon Sep 17 00:00:00 2001 From: Vlad Zolotarov Date: Tue, 16 May 2017 15:53:35 -0400 Subject: [PATCH 2/5] cql_server::response: rework the compress(...) method Cleanup the compress(...) method interface: - Encapsulate the technical details inside the method: - Re-write the _body inside the method instead of returning it. - Set the response::_flags inside the method. Signed-off-by: Vlad Zolotarov --- transport/server.cc | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/transport/server.cc b/transport/server.cc index 3bd0c7929f..d25bbf9ebc 100644 --- a/transport/server.cc +++ b/transport/server.cc @@ -231,7 +231,7 @@ public: return _opcode; } private: - std::vector compress(const std::vector& body, cql_compression compression); + void compress(cql_compression compression); std::vector compress_lz4(const std::vector& body); std::vector compress_snappy(const std::vector& body); @@ -1510,8 +1510,7 @@ scattered_message cql_server::response::make_message(uint8_t version) { future<> cql_server::response::output(output_stream& out, uint8_t version, cql_compression compression) { if (compression != cql_compression::none) { - _body = compress(_body, compression); - set_frame_flag(cql_frame_flags::compression); + compress(compression); } auto frame = make_frame(version, _body.size()); auto tmp = temporary_buffer(frame.size()); @@ -1522,13 +1521,19 @@ cql_server::response::output(output_stream& out, uint8_t version, cql_comp }); } -std::vector cql_server::response::compress(const std::vector& body, cql_compression compression) +void cql_server::response::compress(cql_compression compression) { switch (compression) { - case cql_compression::lz4: return compress_lz4(body); - case cql_compression::snappy: return compress_snappy(body); - default: throw std::invalid_argument("Invalid CQL compression algorithm"); + case cql_compression::lz4: + _body = compress_lz4(_body); + break; + case cql_compression::snappy: + _body = compress_snappy(_body); + break; + default: + throw std::invalid_argument("Invalid CQL compression algorithm"); } + set_frame_flag(cql_frame_flags::compression); } std::vector cql_server::response::compress_lz4(const std::vector& body) From 7706775a634507cceb9e0253285da9af91a207dd Mon Sep 17 00:00:00 2001 From: Vlad Zolotarov Date: Sat, 13 May 2017 14:16:54 -0400 Subject: [PATCH 3/5] utils: serialization: unify the variety of serialize_XXX(...) Use the same templated implementation for all different serialize_XXX(...). The chosen implementation is based on the std::copy_n(char*, size, OutputIterator), which is heavily optimized and will be using memcpy/memmove where possible. This patch also removes the not needed specializations that accept signed integer values since we were casting them to unsigned value anyway. The std::ostream based specifications are also removed since they are not used anywhere except for a test-serialization.cc and adjusting the ostream to the iterator is a single-liner. Signed-off-by: Vlad Zolotarov --- tests/test-serialization.cc | 34 ++++--- utils/serialization.hh | 185 ++++++++++++------------------------ 2 files changed, 84 insertions(+), 135 deletions(-) diff --git a/tests/test-serialization.cc b/tests/test-serialization.cc index ba99577aaa..2a2fa915e4 100644 --- a/tests/test-serialization.cc +++ b/tests/test-serialization.cc @@ -49,31 +49,35 @@ void show(std::stringstream &ss) { // is rotten" type of translation :-) int8_t back_and_forth_8(int8_t a) { std::stringstream buf; - serialize_int8(buf, a); + auto it = std::ostream_iterator(buf); + serialize_int8(it, a); return deserialize_int8(buf); } int16_t back_and_forth_16(int16_t a) { std::stringstream buf; - serialize_int16(buf, a); + auto it = std::ostream_iterator(buf); + serialize_int16(it, a); return deserialize_int16(buf); } int32_t back_and_forth_32(int32_t a) { std::stringstream buf; - serialize_int32(buf, a); + auto it = std::ostream_iterator(buf); + serialize_int32(it, a); return deserialize_int32(buf); } int64_t back_and_forth_64(int64_t a) { std::stringstream buf; - serialize_int64(buf, a); + auto it = std::ostream_iterator(buf); + serialize_int64(it, a); return deserialize_int64(buf); } sstring back_and_forth_sstring(sstring a) { std::stringstream buf; - serialize_string(buf, a); + auto it = std::ostream_iterator(buf); + serialize_string(it, a); return deserialize_string(buf); } BOOST_AUTO_TEST_CASE(round_trip) { - std::stringstream out; BOOST_CHECK_EQUAL(back_and_forth_8('a'), 'a'); BOOST_CHECK_EQUAL(back_and_forth_16(1), 1); BOOST_CHECK_EQUAL(back_and_forth_16(12345), 12345); @@ -141,22 +145,28 @@ bool expect_bytes(std::stringstream &buf, std::initializer_list c BOOST_AUTO_TEST_CASE(expected) { std::stringstream buf; + auto it = std::ostream_iterator(buf); - serialize_int8(buf, 'a'); + serialize_int8(it, 'a'); BOOST_CHECK(expect_bytes(buf, {97})); - serialize_int32(buf, 1234567); + it = std::ostream_iterator(buf); + serialize_int32(it, 1234567); BOOST_CHECK(expect_bytes(buf, {0, 18, 214, 135})); - serialize_int16(buf, (uint16_t)12345); + it = std::ostream_iterator(buf); + serialize_int16(it, (uint16_t)12345); BOOST_CHECK(expect_bytes(buf, {48, 57})); - serialize_int64(buf, 1234567890123UL); + it = std::ostream_iterator(buf); + serialize_int64(it, 1234567890123UL); BOOST_CHECK(expect_bytes(buf, {0, 0, 1, 31, 113, 251, 4, 203})); - serialize_string(buf, "hello"); + it = std::ostream_iterator(buf); + serialize_string(it, "hello"); BOOST_CHECK(expect_bytes(buf, {0, 5, 104, 101, 108, 108, 111})); - serialize_string(buf, sstring("hello")); + it = std::ostream_iterator(buf); + serialize_string(it, sstring("hello")); BOOST_CHECK(expect_bytes(buf, {0, 5, 104, 101, 108, 108, 111})); } diff --git a/utils/serialization.hh b/utils/serialization.hh index 752ab07fe2..382492ed10 100644 --- a/utils/serialization.hh +++ b/utils/serialization.hh @@ -40,36 +40,65 @@ #include +#include #include "core/sstring.hh" #include "net/byteorder.hh" #include "bytes.hh" #include +#include class UTFDataFormatException { }; class EOFException { }; -inline -void serialize_int8(std::ostream& out, uint8_t val) { - out.put(val); -} +static constexpr size_t serialize_int8_size = 1; +static constexpr size_t serialize_bool_size = 1; +static constexpr size_t serialize_int16_size = 2; +static constexpr size_t serialize_int32_size = 4; +static constexpr size_t serialize_int64_size = 8; -inline -void serialize_int8(std::ostream& out, int8_t val) { - out.put(val); -} +namespace internal { +template +GCC6_CONCEPT(requires std::is_integral::value && std::is_integral::value && requires (CharOutputIterator it) { + *it++ = 'a'; +}) inline -void serialize_int8(bytes::iterator& out, uint8_t val) { - uint8_t nval = net::hton(val); +void serialize_int(CharOutputIterator& out, IntegerType val) { + ExplicitIntegerType nval = net::hton(ExplicitIntegerType(val)); out = std::copy_n(reinterpret_cast(&nval), sizeof(nval), out); } -static constexpr size_t serialize_int8_size = 1; +} +template inline -void serialize_int8(std::ostream& out, char val) { - out.put(val); +void serialize_int8(CharOutputIterator& out, uint8_t val) { + internal::serialize_int(out, val); +} + +template +inline +void serialize_int16(CharOutputIterator& out, uint16_t val) { + internal::serialize_int(out, val); +} + +template +inline +void serialize_int32(CharOutputIterator& out, uint32_t val) { + internal::serialize_int(out, val); +} + +template +inline +void serialize_int64(CharOutputIterator& out, uint64_t val) { + internal::serialize_int(out, val); +} + +template +inline +void serialize_bool(CharOutputIterator& out, bool val) { + serialize_int8(out, val ? 1 : 0); } inline @@ -82,44 +111,9 @@ int8_t deserialize_int8(std::istream& in) { } } -inline -void serialize_bool(std::ostream& out, bool b) { - out.put(b ? (char)1 : (char)0); -} - -static constexpr size_t serialize_bool_size = 1; - -inline -void serialize_bool(bytes::iterator& out, bool val) { - serialize_int8(out, val ? 1 : 0); -} - inline bool deserialize_bool(std::istream& in) { - char ret; - if (in.get(ret)) { - return ret; - } else { - throw EOFException(); - } -} - - -inline -void serialize_int16(std::ostream& out, uint16_t val) { - out.put((char)((val >> 8) & 0xFF)); - out.put((char)((val >> 0) & 0xFF)); -} - -inline -void serialize_int16(std::ostream& out, int16_t val) { - serialize_int16(out, (uint16_t) val); -} - -inline -void serialize_int16(bytes::iterator& out, uint16_t val) { - uint16_t nval = net::hton(val); - out = std::copy_n(reinterpret_cast(&nval), sizeof(nval), out); + return deserialize_int8(in); } inline @@ -133,29 +127,6 @@ int16_t deserialize_int16(std::istream& in) { return ((int16_t)(uint8_t)a1 << 8) | ((int16_t)(uint8_t)a2 << 0); } -static constexpr size_t serialize_int16_size = 2; - -inline -void serialize_int32(std::ostream& out, uint32_t val) { - out.put((char)((val >> 24) & 0xFF)); - out.put((char)((val >> 16) & 0xFF)); - out.put((char)((val >> 8) & 0xFF)); - out.put((char)((val >> 0) & 0xFF)); -} - -inline -void serialize_int32(std::ostream& out, int32_t val) { - serialize_int32(out, (uint32_t) val); -} - -inline -void serialize_int32(bytes::iterator& out, uint32_t val) { - uint32_t nval = net::hton(val); - out = std::copy_n(reinterpret_cast(&nval), sizeof(nval), out); -} - -static constexpr size_t serialize_int32_size = 4; - inline int32_t deserialize_int32(std::istream& in) { char a1, a2, a3, a4; @@ -169,31 +140,6 @@ int32_t deserialize_int32(std::istream& in) { ((int32_t)(uint8_t)a4 << 0); } -inline -void serialize_int64(std::ostream& out, uint64_t val) { - out.put((char)((val >> 56) & 0xFF)); - out.put((char)((val >> 48) & 0xFF)); - out.put((char)((val >> 40) & 0xFF)); - out.put((char)((val >> 32) & 0xFF)); - out.put((char)((val >> 24) & 0xFF)); - out.put((char)((val >> 16) & 0xFF)); - out.put((char)((val >> 8) & 0xFF)); - out.put((char)((val >> 0) & 0xFF)); -} - -inline -void serialize_int64(std::ostream& out, int64_t val) { - serialize_int64(out, (uint64_t) val); -} - -inline -void serialize_int64(bytes::iterator& out, uint64_t val) { - uint64_t nval = net::hton(val); - out = std::copy_n(reinterpret_cast(&nval), sizeof(nval), out); -} - -static constexpr size_t serialize_int64_size = 8; - inline int64_t deserialize_int64(std::istream& in) { char a1, a2, a3, a4, a5, a6, a7, a8; @@ -222,8 +168,12 @@ int64_t deserialize_int64(std::istream& in) { // http://docs.oracle.com/javase/7/docs/api/java/io/DataInput.html#modified-utf-8) // For now we'll just assume those aren't in the string... // TODO: fix the compatibility with Java even in this case. +template +GCC6_CONCEPT(requires requires (CharOutputIterator it) { + *it++ = 'a'; +}) inline -void serialize_string(std::ostream& out, const sstring& s) { +void serialize_string(CharOutputIterator& out, const sstring& s) { // Java specifies that nulls in the string need to be replaced by the // two bytes 0xC0, 0x80. Let's not bother with such transformation // now, but just verify wasn't needed. @@ -237,33 +187,16 @@ void serialize_string(std::ostream& out, const sstring& s) { // can't serialize longer strings. throw UTFDataFormatException(); } - serialize_int16(out, (uint16_t) s.size()); - out.write(s.c_str(), s.size()); -} - -inline -void serialize_string(bytes::iterator& out, const sstring& s) { - for (char c : s) { - if (c == '\0') { - throw UTFDataFormatException(); - } - } - if (s.size() > std::numeric_limits::max()) { - throw UTFDataFormatException(); - } - serialize_int16(out, (uint16_t) s.size()); + serialize_int16(out, s.size()); out = std::copy(s.begin(), s.end(), out); } +template +GCC6_CONCEPT(requires requires (CharOutputIterator it) { + *it++ = 'a'; +}) inline -size_t serialize_string_size(const sstring& s) {; - // As above, this code is missing the case of modified utf-8 - return serialize_int16_size + s.size(); -} - - -inline -void serialize_string(std::ostream& out, const char *s) { +void serialize_string(CharOutputIterator& out, const char* s) { // TODO: like above, need to change UTF-8 when above 16-bit. auto len = strlen(s); if (len > std::numeric_limits::max()) { @@ -271,8 +204,14 @@ void serialize_string(std::ostream& out, const char *s) { // can't serialize longer strings. throw UTFDataFormatException(); } - serialize_int16(out, (uint16_t) len); - out.write(s, len); + serialize_int16(out, len); + out = std::copy_n(s, len, out); +} + +inline +size_t serialize_string_size(const sstring& s) {; + // As above, this code is missing the case of modified utf-8 + return serialize_int16_size + s.size(); } inline From 494ea82a88429b4c9616b8903435af842108ec9d Mon Sep 17 00:00:00 2001 From: Vlad Zolotarov Date: Sat, 13 May 2017 12:24:54 -0400 Subject: [PATCH 4/5] utils::UUID: align the UUID serialization API with the similar API of other classes in the project The standard serialization API (e.g. in data_value) includes the following methods: size_t serialized_size() const; void serialize(bytes::iterator& it) const; bytes serialize() const; Align the utils::UUID API with the pattern above. The only addition is that we are going to make an output iterator parameter of a second method above a template so that we may serialize into different output sources. Signed-off-by: Vlad Zolotarov --- tests/mutation_source_test.cc | 2 +- tests/types_test.cc | 8 ++++---- transport/server.cc | 5 +++-- types.cc | 9 ++++----- utils/UUID.hh | 17 +++++++++++++---- 5 files changed, 25 insertions(+), 16 deletions(-) diff --git a/tests/mutation_source_test.cc b/tests/mutation_source_test.cc index c98d9e8e44..7f278940bd 100644 --- a/tests/mutation_source_test.cc +++ b/tests/mutation_source_test.cc @@ -1083,7 +1083,7 @@ public: std::unordered_set unique_cells; unique_cells.reserve(num_cells); for (auto i = 0; i < num_cells; ++i) { - auto uuid = utils::UUID_gen::min_time_UUID(uuid_ts_dist(_gen)).to_bytes(); + auto uuid = utils::UUID_gen::min_time_UUID(uuid_ts_dist(_gen)).serialize(); if (unique_cells.emplace(uuid).second) { m.cells.emplace_back( bytes(reinterpret_cast(uuid.data()), uuid.size()), diff --git a/tests/types_test.cc b/tests/types_test.cc index bcf6572e0f..3d1a628dc2 100644 --- a/tests/types_test.cc +++ b/tests/types_test.cc @@ -539,17 +539,17 @@ BOOST_AUTO_TEST_CASE(test_long_type_validation) { BOOST_AUTO_TEST_CASE(test_timeuuid_type_validation) { auto now = utils::UUID_gen::get_time_UUID(); - timeuuid_type->validate(now.to_bytes()); + timeuuid_type->validate(now.serialize()); auto random = utils::make_random_uuid(); - test_validation_fails(timeuuid_type, random.to_bytes()); + test_validation_fails(timeuuid_type, random.serialize()); test_validation_fails(timeuuid_type, from_hex("00")); } BOOST_AUTO_TEST_CASE(test_uuid_type_validation) { auto now = utils::UUID_gen::get_time_UUID(); - uuid_type->validate(now.to_bytes()); + uuid_type->validate(now.serialize()); auto random = utils::make_random_uuid(); - uuid_type->validate(random.to_bytes()); + uuid_type->validate(random.serialize()); test_validation_fails(uuid_type, from_hex("00")); } diff --git a/transport/server.cc b/transport/server.cc index d25bbf9ebc..23b82521a3 100644 --- a/transport/server.cc +++ b/transport/server.cc @@ -242,7 +242,7 @@ private: // If tracing was requested the response should contain a "tracing // session ID" which is a 16 bytes UUID. if (_tracing_id) { - extra_len += 16; + extra_len += utils::UUID::serialized_size(); set_frame_flag(cql_frame_flags::tracing); } @@ -256,7 +256,8 @@ private: // Tracing session ID should be the first thing in the responce "body". if (_tracing_id) { - std::memcpy(frame_buf.data() + sizeof(CqlFrameHeaderType), _tracing_id->to_bytes().data(), 16); + auto out = frame_buf.data() + sizeof(CqlFrameHeaderType); + _tracing_id->serialize(out); } return frame_buf; diff --git a/types.cc b/types.cc index 81a7199b3f..a51326ab49 100644 --- a/types.cc +++ b/types.cc @@ -520,7 +520,7 @@ struct timeuuid_type_impl : public concrete_type { return; } auto uuid = uuid1.get(); - out = std::copy_n(uuid.to_bytes().begin(), sizeof(uuid), out); + uuid.serialize(out); } virtual size_t serialized_size(const void* value) const override { if (!value || from_value(value).empty()) { @@ -583,7 +583,7 @@ struct timeuuid_type_impl : public concrete_type { if (v.version() != 1) { throw marshal_exception(); } - return v.to_bytes(); + return v.serialize(); } virtual sstring to_string(const bytes& b) const override { auto v = deserialize(b); @@ -972,8 +972,7 @@ struct uuid_type_impl : concrete_type { if (!value) { return; } - auto& uuid = from_value(value); - out = std::copy_n(uuid.get().to_bytes().begin(), sizeof(uuid.get()), out); + from_value(value).get().serialize(out); } virtual size_t serialized_size(const void* value) const override { if (!value) { @@ -1035,7 +1034,7 @@ struct uuid_type_impl : concrete_type { throw marshal_exception(); } utils::UUID v(s); - return v.to_bytes(); + return v.serialize(); } virtual sstring to_string(const bytes& b) const override { auto v = deserialize(b); diff --git a/utils/UUID.hh b/utils/UUID.hh index 96ffed59df..f12bc86ecc 100644 --- a/utils/UUID.hh +++ b/utils/UUID.hh @@ -113,13 +113,22 @@ public: return !(*this < v); } - bytes to_bytes() const { - bytes b(bytes::initialized_later(),16); + bytes serialize() const { + bytes b(bytes::initialized_later(), serialized_size()); auto i = b.begin(); - serialize_int64(i, most_sig_bits); - serialize_int64(i, least_sig_bits); + serialize(i); return b; } + + static size_t serialized_size() noexcept { + return 16; + } + + template + void serialize(CharOutputIterator& out) const { + serialize_int64(out, most_sig_bits); + serialize_int64(out, least_sig_bits); + } }; UUID make_random_uuid(); From a0737abdc5d9d5258595da5b7695bb66c6ce24de Mon Sep 17 00:00:00 2001 From: Vlad Zolotarov Date: Sat, 13 May 2017 16:36:29 -0400 Subject: [PATCH 5/5] cql_server::response: rework the tracing session ID insertion Insert the tracing session ID into the response body in the cql_server::response constructor. Fixes #2356 Signed-off-by: Vlad Zolotarov --- service/client_state.hh | 4 -- tracing/trace_state.hh | 7 ++ transport/server.cc | 142 ++++++++++++++++++---------------------- transport/server.hh | 24 +++---- 4 files changed, 82 insertions(+), 95 deletions(-) diff --git a/service/client_state.hh b/service/client_state.hh index 30ad1959f7..e1a11b9d7a 100644 --- a/service/client_state.hh +++ b/service/client_state.hh @@ -122,10 +122,6 @@ public: return _trace_state_ptr; } - lw_shared_ptr& tracing_session_id_ptr() { - return _tracing_session_id; - } - client_state(external_tag, const socket_address& remote_address = socket_address(), bool thrift = false) : _is_internal(false) , _is_thrift(thrift) diff --git a/tracing/trace_state.hh b/tracing/trace_state.hh index da6fd22a05..839dfa0c5c 100644 --- a/tracing/trace_state.hh +++ b/tracing/trace_state.hh @@ -538,6 +538,13 @@ inline void add_table_name(const trace_state_ptr& p, const sstring& ks_name, con } } +inline bool should_return_id_in_response(const trace_state_ptr& p) { + if (p) { + return p->write_on_close(); + } + return false; +} + /** * A helper for conditional invoking trace_state::begin() functions. * diff --git a/transport/server.cc b/transport/server.cc index 23b82521a3..6f7d4baa0a 100644 --- a/transport/server.cc +++ b/transport/server.cc @@ -187,22 +187,24 @@ class cql_server::response { int16_t _stream; cql_binary_opcode _opcode; uint8_t _flags = 0; // a bitwise OR mask of zero or more cql_frame_flags values - std::experimental::optional _tracing_id; std::vector _body; public: - response(int16_t stream, cql_binary_opcode opcode) + response(int16_t stream, cql_binary_opcode opcode, const tracing::trace_state_ptr& tr_state_ptr) : _stream{stream} , _opcode{opcode} - { } + , _body(tracing::should_return_id_in_response(tr_state_ptr) ? utils::UUID::serialized_size() : 0) + { + if (tracing::should_return_id_in_response(tr_state_ptr)) { + auto i = _body.begin(); + tr_state_ptr->session_id().serialize(i); + set_frame_flag(cql_frame_flags::tracing); + } + } void set_frame_flag(cql_frame_flags flag) noexcept { _flags |= flag; } - void set_tracing_id(const utils::UUID& id) { - _tracing_id = id; - } - scattered_message make_message(uint8_t version); void serialize(const event::schema_change& event, uint8_t version); void write_byte(uint8_t b); @@ -237,29 +239,14 @@ private: template sstring make_frame_one(uint8_t version, size_t length) { - size_t extra_len = 0; - - // If tracing was requested the response should contain a "tracing - // session ID" which is a 16 bytes UUID. - if (_tracing_id) { - extra_len += utils::UUID::serialized_size(); - set_frame_flag(cql_frame_flags::tracing); - } - - sstring frame_buf(sstring::initialized_later(), sizeof(CqlFrameHeaderType) + extra_len); + sstring frame_buf(sstring::initialized_later(), sizeof(CqlFrameHeaderType)); auto* frame = reinterpret_cast(frame_buf.begin()); frame->version = version | 0x80; frame->flags = _flags; frame->opcode = static_cast(_opcode); - frame->length = htonl(length + extra_len); + frame->length = htonl(length); frame->stream = net::hton((decltype(frame->stream))_stream); - // Tracing session ID should be the first thing in the responce "body". - if (_tracing_id) { - auto out = frame_buf.data() + sizeof(CqlFrameHeaderType); - _tracing_id->serialize(out); - } - return frame_buf; } @@ -547,21 +534,21 @@ future } return make_ready_future(response); } catch (const exceptions::unavailable_exception& ex) { - return make_ready_future(std::make_pair(make_unavailable_error(stream, ex.code(), ex.what(), ex.consistency, ex.required, ex.alive), client_state)); + return make_ready_future(std::make_pair(make_unavailable_error(stream, ex.code(), ex.what(), ex.consistency, ex.required, ex.alive, client_state.get_trace_state()), client_state)); } catch (const exceptions::read_timeout_exception& ex) { - return make_ready_future(std::make_pair(make_read_timeout_error(stream, ex.code(), ex.what(), ex.consistency, ex.received, ex.block_for, ex.data_present), client_state)); + return make_ready_future(std::make_pair(make_read_timeout_error(stream, ex.code(), ex.what(), ex.consistency, ex.received, ex.block_for, ex.data_present, client_state.get_trace_state()), client_state)); } catch (const exceptions::mutation_write_timeout_exception& ex) { - return make_ready_future(std::make_pair(make_mutation_write_timeout_error(stream, ex.code(), ex.what(), ex.consistency, ex.received, ex.block_for, ex.type), client_state)); + return make_ready_future(std::make_pair(make_mutation_write_timeout_error(stream, ex.code(), ex.what(), ex.consistency, ex.received, ex.block_for, ex.type, client_state.get_trace_state()), client_state)); } catch (const exceptions::already_exists_exception& ex) { - return make_ready_future(std::make_pair(make_already_exists_error(stream, ex.code(), ex.what(), ex.ks_name, ex.cf_name), client_state)); + return make_ready_future(std::make_pair(make_already_exists_error(stream, ex.code(), ex.what(), ex.ks_name, ex.cf_name, client_state.get_trace_state()), client_state)); } catch (const exceptions::prepared_query_not_found_exception& ex) { - return make_ready_future(std::make_pair(make_unprepared_error(stream, ex.code(), ex.what(), ex.id), client_state)); + return make_ready_future(std::make_pair(make_unprepared_error(stream, ex.code(), ex.what(), ex.id, client_state.get_trace_state()), client_state)); } catch (const exceptions::cassandra_exception& ex) { - return make_ready_future(std::make_pair(make_error(stream, ex.code(), ex.what()), client_state)); + return make_ready_future(std::make_pair(make_error(stream, ex.code(), ex.what(), client_state.get_trace_state()), client_state)); } catch (std::exception& ex) { - return make_ready_future(std::make_pair(make_error(stream, exceptions::exception_code::SERVER_ERROR, ex.what()), client_state)); + return make_ready_future(std::make_pair(make_error(stream, exceptions::exception_code::SERVER_ERROR, ex.what(), client_state.get_trace_state()), client_state)); } catch (...) { - return make_ready_future(std::make_pair(make_error(stream, exceptions::exception_code::SERVER_ERROR, "unknown error"), client_state)); + return make_ready_future(std::make_pair(make_error(stream, exceptions::exception_code::SERVER_ERROR, "unknown error", client_state.get_trace_state()), client_state)); } }).finally([tracing_state = client_state.get_trace_state()] { tracing::stop_foreground(tracing_state); @@ -598,11 +585,11 @@ future<> cql_server::connection::process() f.get(); return make_ready_future<>(); } catch (const exceptions::cassandra_exception& ex) { - return write_response(make_error(0, ex.code(), ex.what())); + return write_response(make_error(0, ex.code(), ex.what(), tracing::trace_state_ptr())); } catch (std::exception& ex) { - return write_response(make_error(0, exceptions::exception_code::SERVER_ERROR, ex.what())); + return write_response(make_error(0, exceptions::exception_code::SERVER_ERROR, ex.what(), tracing::trace_state_ptr())); } catch (...) { - return write_response(make_error(0, exceptions::exception_code::SERVER_ERROR, "unknown error")); + return write_response(make_error(0, exceptions::exception_code::SERVER_ERROR, "unknown error", tracing::trace_state_ptr())); } }).finally([this] { _server._notifier->unregister_connection(this); @@ -667,12 +654,7 @@ future<> cql_server::connection::process_request() { auto bv = bytes_view{reinterpret_cast(buf.begin()), buf.size()}; auto cpu = pick_request_cpu(); return smp::submit_to(cpu, [this, bv = std::move(bv), op, stream, client_state = _client_state, tracing_requested] () mutable { - return process_request_stage(this, bv, op, stream, std::move(client_state), tracing_requested).then([tracing_requested](auto&& response) { - auto& tracing_session_id_ptr = response.second.tracing_session_id_ptr(); - // report a tracing session ID only if it was explicitly requested to trace this particular query - if (tracing_requested == tracing_request_type::write_on_close && tracing_session_id_ptr) { - response.first->set_tracing_id(*tracing_session_id_ptr); - } + return process_request_stage(this, bv, op, stream, std::move(client_state), tracing_requested).then([] (auto&& response) { return std::make_pair(make_foreign(response.first), response.second); }); }).then([this, flags] (auto&& response) { @@ -770,9 +752,9 @@ future cql_server::connection::process_startup(uint16_t stream, b } auto& a = auth::authenticator::get(); if (a.require_authentication()) { - return make_ready_future(std::make_pair(make_autheticate(stream, a.class_name()), client_state)); + return make_ready_future(std::make_pair(make_autheticate(stream, a.class_name(), client_state.get_trace_state()), client_state)); } - return make_ready_future(std::make_pair(make_ready(stream), client_state)); + return make_ready_future(std::make_pair(make_ready(stream, client_state.get_trace_state()), client_state)); } future cql_server::connection::process_auth_response(uint16_t stream, bytes_view buf, service::client_state client_state) @@ -787,16 +769,18 @@ future cql_server::connection::process_auth_response(uint16_t str client_state.set_login(std::move(user)); auto f = client_state.check_user_exists(); return f.then([this, stream, client_state = std::move(client_state), challenge = std::move(challenge)]() mutable { - return make_ready_future(std::make_pair(make_auth_success(stream, std::move(challenge)), std::move(client_state))); + auto tr_state = client_state.get_trace_state(); + return make_ready_future(std::make_pair(make_auth_success(stream, std::move(challenge), tr_state), std::move(client_state))); }); }); } - return make_ready_future(std::make_pair(make_auth_challenge(stream, std::move(challenge)), std::move(client_state))); + auto tr_state = client_state.get_trace_state(); + return make_ready_future(std::make_pair(make_auth_challenge(stream, std::move(challenge), tr_state), std::move(client_state))); } future cql_server::connection::process_options(uint16_t stream, bytes_view buf, service::client_state client_state) { - return make_ready_future(std::make_pair(make_supported(stream), client_state)); + return make_ready_future(std::make_pair(make_supported(stream, client_state.get_trace_state()), client_state)); } void @@ -823,7 +807,7 @@ future cql_server::connection::process_query(uint16_t stream, byt return _server._query_processor.local().process(query, query_state, options).then([this, stream, buf = std::move(buf), &query_state, skip_metadata] (auto msg) { tracing::trace(query_state.get_trace_state(), "Done processing - preparing a result"); - return this->make_result(stream, msg, skip_metadata); + return this->make_result(stream, msg, query_state.get_trace_state(), skip_metadata); }).then([&query_state, q_state = std::move(q_state), this] (auto&& response) { /* Keep q_state alive. */ return make_ready_future(std::make_pair(response, query_state.get_client_state())); @@ -856,7 +840,7 @@ future cql_server::connection::process_prepare(uint16_t stream, b tracing::trace(cs.get_trace_state(), "Done preparing on a local shard - preparing a result. ID is [{}]", seastar::value_of([&msg] { return messages::result_message::prepared::cql::get_id(msg); })); - return this->make_result(stream, msg); + return this->make_result(stream, msg, cs.get_trace_state()); }); }).then([client_state = std::move(client_state)] (auto&& response) { /* keep client_state alive */ @@ -904,7 +888,7 @@ future cql_server::connection::process_execute(uint16_t stream, b tracing::trace(query_state.get_trace_state(), "Processing a statement"); return _server._query_processor.local().process_statement(stmt, query_state, options).then([this, stream, buf = std::move(buf), &query_state, skip_metadata] (auto msg) { tracing::trace(query_state.get_trace_state(), "Done processing - preparing a result"); - return this->make_result(stream, msg, skip_metadata); + return this->make_result(stream, msg, query_state.get_trace_state(), skip_metadata); }).then([&query_state, q_state = std::move(q_state), this] (auto&& response) { /* Keep q_state alive. */ return make_ready_future(std::make_pair(response, query_state.get_client_state())); @@ -987,8 +971,8 @@ cql_server::connection::process_batch(uint16_t stream, bytes_view buf, service:: tracing::trace(client_state.get_trace_state(), "Creating a batch statement"); auto batch = ::make_shared(cql3::statements::batch_statement::type(type), std::move(modifications), cql3::attributes::none(), _server._query_processor.local().get_cql_stats()); - return _server._query_processor.local().process_batch(batch, query_state, options).then([this, stream, batch] (auto msg) { - return this->make_result(stream, msg); + return _server._query_processor.local().process_batch(batch, query_state, options).then([this, stream, batch, &query_state] (auto msg) { + return this->make_result(stream, msg, query_state.get_trace_state()); }).then([&query_state, q_state = std::move(q_state), this] (auto&& response) { /* Keep q_state alive. */ return make_ready_future(std::make_pair(response, query_state.get_client_state())); @@ -1004,12 +988,12 @@ cql_server::connection::process_register(uint16_t stream, bytes_view buf, servic auto et = parse_event_type(event_type); _server._notifier->register_event(et, this); } - return make_ready_future(std::make_pair(make_ready(stream), client_state)); + return make_ready_future(std::make_pair(make_ready(stream, client_state.get_trace_state()), client_state)); } -shared_ptr cql_server::connection::make_unavailable_error(int16_t stream, exceptions::exception_code err, sstring msg, db::consistency_level cl, int32_t required, int32_t alive) +shared_ptr cql_server::connection::make_unavailable_error(int16_t stream, exceptions::exception_code err, sstring msg, db::consistency_level cl, int32_t required, int32_t alive, const tracing::trace_state_ptr& tr_state) { - auto response = make_shared(stream, cql_binary_opcode::ERROR); + auto response = make_shared(stream, cql_binary_opcode::ERROR, tr_state); response->write_int(static_cast(err)); response->write_string(msg); response->write_consistency(cl); @@ -1018,9 +1002,9 @@ shared_ptr cql_server::connection::make_unavailable_error( return response; } -shared_ptr cql_server::connection::make_read_timeout_error(int16_t stream, exceptions::exception_code err, sstring msg, db::consistency_level cl, int32_t received, int32_t blockfor, bool data_present) +shared_ptr cql_server::connection::make_read_timeout_error(int16_t stream, exceptions::exception_code err, sstring msg, db::consistency_level cl, int32_t received, int32_t blockfor, bool data_present, const tracing::trace_state_ptr& tr_state) { - auto response = make_shared(stream, cql_binary_opcode::ERROR); + auto response = make_shared(stream, cql_binary_opcode::ERROR, tr_state); response->write_int(static_cast(err)); response->write_string(msg); response->write_consistency(cl); @@ -1030,9 +1014,9 @@ shared_ptr cql_server::connection::make_read_timeout_error return response; } -shared_ptr cql_server::connection::make_mutation_write_timeout_error(int16_t stream, exceptions::exception_code err, sstring msg, db::consistency_level cl, int32_t received, int32_t blockfor, db::write_type type) +shared_ptr cql_server::connection::make_mutation_write_timeout_error(int16_t stream, exceptions::exception_code err, sstring msg, db::consistency_level cl, int32_t received, int32_t blockfor, db::write_type type, const tracing::trace_state_ptr& tr_state) { - auto response = make_shared(stream, cql_binary_opcode::ERROR); + auto response = make_shared(stream, cql_binary_opcode::ERROR, tr_state); response->write_int(static_cast(err)); response->write_string(msg); response->write_consistency(cl); @@ -1042,9 +1026,9 @@ shared_ptr cql_server::connection::make_mutation_write_tim return response; } -shared_ptr cql_server::connection::make_already_exists_error(int16_t stream, exceptions::exception_code err, sstring msg, sstring ks_name, sstring cf_name) +shared_ptr cql_server::connection::make_already_exists_error(int16_t stream, exceptions::exception_code err, sstring msg, sstring ks_name, sstring cf_name, const tracing::trace_state_ptr& tr_state) { - auto response = make_shared(stream, cql_binary_opcode::ERROR); + auto response = make_shared(stream, cql_binary_opcode::ERROR, tr_state); response->write_int(static_cast(err)); response->write_string(msg); response->write_string(ks_name); @@ -1052,54 +1036,54 @@ shared_ptr cql_server::connection::make_already_exists_err return response; } -shared_ptr cql_server::connection::make_unprepared_error(int16_t stream, exceptions::exception_code err, sstring msg, bytes id) +shared_ptr cql_server::connection::make_unprepared_error(int16_t stream, exceptions::exception_code err, sstring msg, bytes id, const tracing::trace_state_ptr& tr_state) { - auto response = make_shared(stream, cql_binary_opcode::ERROR); + auto response = make_shared(stream, cql_binary_opcode::ERROR, tr_state); response->write_int(static_cast(err)); response->write_string(msg); response->write_short_bytes(id); return response; } -shared_ptr cql_server::connection::make_error(int16_t stream, exceptions::exception_code err, sstring msg) +shared_ptr cql_server::connection::make_error(int16_t stream, exceptions::exception_code err, sstring msg, const tracing::trace_state_ptr& tr_state) { - auto response = make_shared(stream, cql_binary_opcode::ERROR); + auto response = make_shared(stream, cql_binary_opcode::ERROR, tr_state); response->write_int(static_cast(err)); response->write_string(msg); return response; } -shared_ptr cql_server::connection::make_ready(int16_t stream) +shared_ptr cql_server::connection::make_ready(int16_t stream, const tracing::trace_state_ptr& tr_state) { - return make_shared(stream, cql_binary_opcode::READY); + return make_shared(stream, cql_binary_opcode::READY, tr_state); } -shared_ptr cql_server::connection::make_autheticate(int16_t stream, const sstring& clz) +shared_ptr cql_server::connection::make_autheticate(int16_t stream, const sstring& clz, const tracing::trace_state_ptr& tr_state) { - auto response = make_shared(stream, cql_binary_opcode::AUTHENTICATE); + auto response = make_shared(stream, cql_binary_opcode::AUTHENTICATE, tr_state); response->write_string(clz); return response; } -shared_ptr cql_server::connection::make_auth_success(int16_t stream, bytes b) { - auto response = make_shared(stream, cql_binary_opcode::AUTH_SUCCESS); +shared_ptr cql_server::connection::make_auth_success(int16_t stream, bytes b, const tracing::trace_state_ptr& tr_state) { + auto response = make_shared(stream, cql_binary_opcode::AUTH_SUCCESS, tr_state); response->write_bytes(std::move(b)); return response; } -shared_ptr cql_server::connection::make_auth_challenge(int16_t stream, bytes b) { - auto response = make_shared(stream, cql_binary_opcode::AUTH_CHALLENGE); +shared_ptr cql_server::connection::make_auth_challenge(int16_t stream, bytes b, const tracing::trace_state_ptr& tr_state) { + auto response = make_shared(stream, cql_binary_opcode::AUTH_CHALLENGE, tr_state); response->write_bytes(std::move(b)); return response; } -shared_ptr cql_server::connection::make_supported(int16_t stream) +shared_ptr cql_server::connection::make_supported(int16_t stream, const tracing::trace_state_ptr& tr_state) { std::multimap opts; opts.insert({"CQL_VERSION", cql3::query_processor::CQL_VERSION}); opts.insert({"COMPRESSION", "lz4"}); opts.insert({"COMPRESSION", "snappy"}); - auto response = make_shared(stream, cql_binary_opcode::SUPPORTED); + auto response = make_shared(stream, cql_binary_opcode::SUPPORTED, tr_state); response->write_string_multimap(opts); return response; } @@ -1162,9 +1146,9 @@ public: }; shared_ptr -cql_server::connection::make_result(int16_t stream, shared_ptr msg, bool skip_metadata) +cql_server::connection::make_result(int16_t stream, shared_ptr msg, const tracing::trace_state_ptr& tr_state, bool skip_metadata) { - auto response = make_shared(stream, cql_binary_opcode::RESULT); + auto response = make_shared(stream, cql_binary_opcode::RESULT, tr_state); fmt_visitor fmt{_version, response, skip_metadata}; msg->accept(fmt); return response; @@ -1173,7 +1157,7 @@ cql_server::connection::make_result(int16_t stream, shared_ptr cql_server::connection::make_topology_change_event(const event::topology_change& event) { - auto response = make_shared(-1, cql_binary_opcode::EVENT); + auto response = make_shared(-1, cql_binary_opcode::EVENT, tracing::trace_state_ptr()); response->write_string("TOPOLOGY_CHANGE"); response->write_string(to_string(event.change)); response->write_inet(event.node); @@ -1183,7 +1167,7 @@ cql_server::connection::make_topology_change_event(const event::topology_change& shared_ptr cql_server::connection::make_status_change_event(const event::status_change& event) { - auto response = make_shared(-1, cql_binary_opcode::EVENT); + auto response = make_shared(-1, cql_binary_opcode::EVENT, tracing::trace_state_ptr()); response->write_string("STATUS_CHANGE"); response->write_string(to_string(event.status)); response->write_inet(event.node); @@ -1193,7 +1177,7 @@ cql_server::connection::make_status_change_event(const event::status_change& eve shared_ptr cql_server::connection::make_schema_change_event(const event::schema_change& event) { - auto response = make_shared(-1, cql_binary_opcode::EVENT); + auto response = make_shared(-1, cql_binary_opcode::EVENT, tracing::trace_state_ptr()); response->write_string("SCHEMA_CHANGE"); response->serialize(event, _version); return response; diff --git a/transport/server.hh b/transport/server.hh index 603cd927fb..b47ceeaf0a 100644 --- a/transport/server.hh +++ b/transport/server.hh @@ -177,21 +177,21 @@ private: future process_batch(uint16_t stream, bytes_view buf, service::client_state client_state); future process_register(uint16_t stream, bytes_view buf, service::client_state client_state); - shared_ptr make_unavailable_error(int16_t stream, exceptions::exception_code err, sstring msg, db::consistency_level cl, int32_t required, int32_t alive); - shared_ptr make_read_timeout_error(int16_t stream, exceptions::exception_code err, sstring msg, db::consistency_level cl, int32_t received, int32_t blockfor, bool data_present); - shared_ptr make_mutation_write_timeout_error(int16_t stream, exceptions::exception_code err, sstring msg, db::consistency_level cl, int32_t received, int32_t blockfor, db::write_type type); - shared_ptr make_already_exists_error(int16_t stream, exceptions::exception_code err, sstring msg, sstring ks_name, sstring cf_name); - shared_ptr make_unprepared_error(int16_t stream, exceptions::exception_code err, sstring msg, bytes id); - shared_ptr make_error(int16_t stream, exceptions::exception_code err, sstring msg); - shared_ptr make_ready(int16_t stream); - shared_ptr make_supported(int16_t stream); - shared_ptr make_result(int16_t stream, shared_ptr msg, bool skip_metadata = false); + shared_ptr make_unavailable_error(int16_t stream, exceptions::exception_code err, sstring msg, db::consistency_level cl, int32_t required, int32_t alive, const tracing::trace_state_ptr& tr_state); + shared_ptr make_read_timeout_error(int16_t stream, exceptions::exception_code err, sstring msg, db::consistency_level cl, int32_t received, int32_t blockfor, bool data_present, const tracing::trace_state_ptr& tr_state); + shared_ptr make_mutation_write_timeout_error(int16_t stream, exceptions::exception_code err, sstring msg, db::consistency_level cl, int32_t received, int32_t blockfor, db::write_type type, const tracing::trace_state_ptr& tr_state); + shared_ptr make_already_exists_error(int16_t stream, exceptions::exception_code err, sstring msg, sstring ks_name, sstring cf_name, const tracing::trace_state_ptr& tr_state); + shared_ptr make_unprepared_error(int16_t stream, exceptions::exception_code err, sstring msg, bytes id, const tracing::trace_state_ptr& tr_state); + shared_ptr make_error(int16_t stream, exceptions::exception_code err, sstring msg, const tracing::trace_state_ptr& tr_state); + shared_ptr make_ready(int16_t stream, const tracing::trace_state_ptr& tr_state); + shared_ptr make_supported(int16_t stream, const tracing::trace_state_ptr& tr_state); + shared_ptr make_result(int16_t stream, shared_ptr msg, const tracing::trace_state_ptr& tr_state, bool skip_metadata = false); shared_ptr make_topology_change_event(const transport::event::topology_change& event); shared_ptr make_status_change_event(const transport::event::status_change& event); shared_ptr make_schema_change_event(const transport::event::schema_change& event); - shared_ptr make_autheticate(int16_t, const sstring&); - shared_ptr make_auth_success(int16_t, bytes); - shared_ptr make_auth_challenge(int16_t, bytes); + shared_ptr make_autheticate(int16_t, const sstring&, const tracing::trace_state_ptr& tr_state); + shared_ptr make_auth_success(int16_t, bytes, const tracing::trace_state_ptr& tr_state); + shared_ptr make_auth_challenge(int16_t, bytes, const tracing::trace_state_ptr& tr_state); future<> write_response(foreign_ptr>&& response, cql_compression compression = cql_compression::none);