diff --git a/service/client_state.hh b/service/client_state.hh index 30ad1959f7..e1a11b9d7a 100644 --- a/service/client_state.hh +++ b/service/client_state.hh @@ -122,10 +122,6 @@ public: return _trace_state_ptr; } - lw_shared_ptr& tracing_session_id_ptr() { - return _tracing_session_id; - } - client_state(external_tag, const socket_address& remote_address = socket_address(), bool thrift = false) : _is_internal(false) , _is_thrift(thrift) diff --git a/tests/mutation_source_test.cc b/tests/mutation_source_test.cc index c98d9e8e44..7f278940bd 100644 --- a/tests/mutation_source_test.cc +++ b/tests/mutation_source_test.cc @@ -1083,7 +1083,7 @@ public: std::unordered_set unique_cells; unique_cells.reserve(num_cells); for (auto i = 0; i < num_cells; ++i) { - auto uuid = utils::UUID_gen::min_time_UUID(uuid_ts_dist(_gen)).to_bytes(); + auto uuid = utils::UUID_gen::min_time_UUID(uuid_ts_dist(_gen)).serialize(); if (unique_cells.emplace(uuid).second) { m.cells.emplace_back( bytes(reinterpret_cast(uuid.data()), uuid.size()), diff --git a/tests/test-serialization.cc b/tests/test-serialization.cc index ba99577aaa..2a2fa915e4 100644 --- a/tests/test-serialization.cc +++ b/tests/test-serialization.cc @@ -49,31 +49,35 @@ void show(std::stringstream &ss) { // is rotten" type of translation :-) int8_t back_and_forth_8(int8_t a) { std::stringstream buf; - serialize_int8(buf, a); + auto it = std::ostream_iterator(buf); + serialize_int8(it, a); return deserialize_int8(buf); } int16_t back_and_forth_16(int16_t a) { std::stringstream buf; - serialize_int16(buf, a); + auto it = std::ostream_iterator(buf); + serialize_int16(it, a); return deserialize_int16(buf); } int32_t back_and_forth_32(int32_t a) { std::stringstream buf; - serialize_int32(buf, a); + auto it = std::ostream_iterator(buf); + serialize_int32(it, a); return deserialize_int32(buf); } int64_t back_and_forth_64(int64_t a) { std::stringstream buf; - serialize_int64(buf, a); + auto it = std::ostream_iterator(buf); + serialize_int64(it, a); return deserialize_int64(buf); } sstring back_and_forth_sstring(sstring a) { std::stringstream buf; - serialize_string(buf, a); + auto it = std::ostream_iterator(buf); + serialize_string(it, a); return deserialize_string(buf); } BOOST_AUTO_TEST_CASE(round_trip) { - std::stringstream out; BOOST_CHECK_EQUAL(back_and_forth_8('a'), 'a'); BOOST_CHECK_EQUAL(back_and_forth_16(1), 1); BOOST_CHECK_EQUAL(back_and_forth_16(12345), 12345); @@ -141,22 +145,28 @@ bool expect_bytes(std::stringstream &buf, std::initializer_list c BOOST_AUTO_TEST_CASE(expected) { std::stringstream buf; + auto it = std::ostream_iterator(buf); - serialize_int8(buf, 'a'); + serialize_int8(it, 'a'); BOOST_CHECK(expect_bytes(buf, {97})); - serialize_int32(buf, 1234567); + it = std::ostream_iterator(buf); + serialize_int32(it, 1234567); BOOST_CHECK(expect_bytes(buf, {0, 18, 214, 135})); - serialize_int16(buf, (uint16_t)12345); + it = std::ostream_iterator(buf); + serialize_int16(it, (uint16_t)12345); BOOST_CHECK(expect_bytes(buf, {48, 57})); - serialize_int64(buf, 1234567890123UL); + it = std::ostream_iterator(buf); + serialize_int64(it, 1234567890123UL); BOOST_CHECK(expect_bytes(buf, {0, 0, 1, 31, 113, 251, 4, 203})); - serialize_string(buf, "hello"); + it = std::ostream_iterator(buf); + serialize_string(it, "hello"); BOOST_CHECK(expect_bytes(buf, {0, 5, 104, 101, 108, 108, 111})); - serialize_string(buf, sstring("hello")); + it = std::ostream_iterator(buf); + serialize_string(it, sstring("hello")); BOOST_CHECK(expect_bytes(buf, {0, 5, 104, 101, 108, 108, 111})); } diff --git a/tests/types_test.cc b/tests/types_test.cc index bcf6572e0f..3d1a628dc2 100644 --- a/tests/types_test.cc +++ b/tests/types_test.cc @@ -539,17 +539,17 @@ BOOST_AUTO_TEST_CASE(test_long_type_validation) { BOOST_AUTO_TEST_CASE(test_timeuuid_type_validation) { auto now = utils::UUID_gen::get_time_UUID(); - timeuuid_type->validate(now.to_bytes()); + timeuuid_type->validate(now.serialize()); auto random = utils::make_random_uuid(); - test_validation_fails(timeuuid_type, random.to_bytes()); + test_validation_fails(timeuuid_type, random.serialize()); test_validation_fails(timeuuid_type, from_hex("00")); } BOOST_AUTO_TEST_CASE(test_uuid_type_validation) { auto now = utils::UUID_gen::get_time_UUID(); - uuid_type->validate(now.to_bytes()); + uuid_type->validate(now.serialize()); auto random = utils::make_random_uuid(); - uuid_type->validate(random.to_bytes()); + uuid_type->validate(random.serialize()); test_validation_fails(uuid_type, from_hex("00")); } diff --git a/tracing/trace_state.hh b/tracing/trace_state.hh index da6fd22a05..839dfa0c5c 100644 --- a/tracing/trace_state.hh +++ b/tracing/trace_state.hh @@ -538,6 +538,13 @@ inline void add_table_name(const trace_state_ptr& p, const sstring& ks_name, con } } +inline bool should_return_id_in_response(const trace_state_ptr& p) { + if (p) { + return p->write_on_close(); + } + return false; +} + /** * A helper for conditional invoking trace_state::begin() functions. * diff --git a/transport/server.cc b/transport/server.cc index 7b41de047e..6f7d4baa0a 100644 --- a/transport/server.cc +++ b/transport/server.cc @@ -186,16 +186,23 @@ cql_load_balance parse_load_balance(sstring value) class cql_server::response { int16_t _stream; cql_binary_opcode _opcode; - std::experimental::optional _tracing_id; + uint8_t _flags = 0; // a bitwise OR mask of zero or more cql_frame_flags values std::vector _body; public: - response(int16_t stream, cql_binary_opcode opcode) + response(int16_t stream, cql_binary_opcode opcode, const tracing::trace_state_ptr& tr_state_ptr) : _stream{stream} , _opcode{opcode} - { } + , _body(tracing::should_return_id_in_response(tr_state_ptr) ? utils::UUID::serialized_size() : 0) + { + if (tracing::should_return_id_in_response(tr_state_ptr)) { + auto i = _body.begin(); + tr_state_ptr->session_id().serialize(i); + set_frame_flag(cql_frame_flags::tracing); + } + } - void set_tracing_id(const utils::UUID& id) { - _tracing_id = id; + void set_frame_flag(cql_frame_flags flag) noexcept { + _flags |= flag; } scattered_message make_message(uint8_t version); @@ -226,46 +233,32 @@ public: return _opcode; } private: - std::vector compress(const std::vector& body, cql_compression compression); + void compress(cql_compression compression); std::vector compress_lz4(const std::vector& body); std::vector compress_snappy(const std::vector& body); template - sstring make_frame_one(uint8_t version, uint8_t flags, size_t length) { - size_t extra_len = 0; - - // If tracing was requested the response should contain a "tracing - // session ID" which is a 16 bytes UUID. - if (_tracing_id) { - extra_len += 16; - flags |= cql_frame_flags::tracing; - } - - sstring frame_buf(sstring::initialized_later(), sizeof(CqlFrameHeaderType) + extra_len); + sstring make_frame_one(uint8_t version, size_t length) { + sstring frame_buf(sstring::initialized_later(), sizeof(CqlFrameHeaderType)); auto* frame = reinterpret_cast(frame_buf.begin()); frame->version = version | 0x80; - frame->flags = flags; + frame->flags = _flags; frame->opcode = static_cast(_opcode); - frame->length = htonl(length + extra_len); + frame->length = htonl(length); frame->stream = net::hton((decltype(frame->stream))_stream); - // Tracing session ID should be the first thing in the responce "body". - if (_tracing_id) { - std::memcpy(frame_buf.data() + sizeof(CqlFrameHeaderType), _tracing_id->to_bytes().data(), 16); - } - return frame_buf; } - sstring make_frame(uint8_t version, uint8_t flags, size_t length) { + sstring make_frame(uint8_t version, size_t length) { if (version > 0x04) { throw exceptions::protocol_exception(sprint("Invalid or unsupported protocol version: %d", version)); } if (version > 0x02) { - return make_frame_one(version, flags, length); + return make_frame_one(version, length); } else { - return make_frame_one(version, flags, length); + return make_frame_one(version, length); } } }; @@ -541,21 +534,21 @@ future } return make_ready_future(response); } catch (const exceptions::unavailable_exception& ex) { - return make_ready_future(std::make_pair(make_unavailable_error(stream, ex.code(), ex.what(), ex.consistency, ex.required, ex.alive), client_state)); + return make_ready_future(std::make_pair(make_unavailable_error(stream, ex.code(), ex.what(), ex.consistency, ex.required, ex.alive, client_state.get_trace_state()), client_state)); } catch (const exceptions::read_timeout_exception& ex) { - return make_ready_future(std::make_pair(make_read_timeout_error(stream, ex.code(), ex.what(), ex.consistency, ex.received, ex.block_for, ex.data_present), client_state)); + return make_ready_future(std::make_pair(make_read_timeout_error(stream, ex.code(), ex.what(), ex.consistency, ex.received, ex.block_for, ex.data_present, client_state.get_trace_state()), client_state)); } catch (const exceptions::mutation_write_timeout_exception& ex) { - return make_ready_future(std::make_pair(make_mutation_write_timeout_error(stream, ex.code(), ex.what(), ex.consistency, ex.received, ex.block_for, ex.type), client_state)); + return make_ready_future(std::make_pair(make_mutation_write_timeout_error(stream, ex.code(), ex.what(), ex.consistency, ex.received, ex.block_for, ex.type, client_state.get_trace_state()), client_state)); } catch (const exceptions::already_exists_exception& ex) { - return make_ready_future(std::make_pair(make_already_exists_error(stream, ex.code(), ex.what(), ex.ks_name, ex.cf_name), client_state)); + return make_ready_future(std::make_pair(make_already_exists_error(stream, ex.code(), ex.what(), ex.ks_name, ex.cf_name, client_state.get_trace_state()), client_state)); } catch (const exceptions::prepared_query_not_found_exception& ex) { - return make_ready_future(std::make_pair(make_unprepared_error(stream, ex.code(), ex.what(), ex.id), client_state)); + return make_ready_future(std::make_pair(make_unprepared_error(stream, ex.code(), ex.what(), ex.id, client_state.get_trace_state()), client_state)); } catch (const exceptions::cassandra_exception& ex) { - return make_ready_future(std::make_pair(make_error(stream, ex.code(), ex.what()), client_state)); + return make_ready_future(std::make_pair(make_error(stream, ex.code(), ex.what(), client_state.get_trace_state()), client_state)); } catch (std::exception& ex) { - return make_ready_future(std::make_pair(make_error(stream, exceptions::exception_code::SERVER_ERROR, ex.what()), client_state)); + return make_ready_future(std::make_pair(make_error(stream, exceptions::exception_code::SERVER_ERROR, ex.what(), client_state.get_trace_state()), client_state)); } catch (...) { - return make_ready_future(std::make_pair(make_error(stream, exceptions::exception_code::SERVER_ERROR, "unknown error"), client_state)); + return make_ready_future(std::make_pair(make_error(stream, exceptions::exception_code::SERVER_ERROR, "unknown error", client_state.get_trace_state()), client_state)); } }).finally([tracing_state = client_state.get_trace_state()] { tracing::stop_foreground(tracing_state); @@ -592,11 +585,11 @@ future<> cql_server::connection::process() f.get(); return make_ready_future<>(); } catch (const exceptions::cassandra_exception& ex) { - return write_response(make_error(0, ex.code(), ex.what())); + return write_response(make_error(0, ex.code(), ex.what(), tracing::trace_state_ptr())); } catch (std::exception& ex) { - return write_response(make_error(0, exceptions::exception_code::SERVER_ERROR, ex.what())); + return write_response(make_error(0, exceptions::exception_code::SERVER_ERROR, ex.what(), tracing::trace_state_ptr())); } catch (...) { - return write_response(make_error(0, exceptions::exception_code::SERVER_ERROR, "unknown error")); + return write_response(make_error(0, exceptions::exception_code::SERVER_ERROR, "unknown error", tracing::trace_state_ptr())); } }).finally([this] { _server._notifier->unregister_connection(this); @@ -661,12 +654,7 @@ future<> cql_server::connection::process_request() { auto bv = bytes_view{reinterpret_cast(buf.begin()), buf.size()}; auto cpu = pick_request_cpu(); return smp::submit_to(cpu, [this, bv = std::move(bv), op, stream, client_state = _client_state, tracing_requested] () mutable { - return process_request_stage(this, bv, op, stream, std::move(client_state), tracing_requested).then([tracing_requested](auto&& response) { - auto& tracing_session_id_ptr = response.second.tracing_session_id_ptr(); - // report a tracing session ID only if it was explicitly requested to trace this particular query - if (tracing_requested == tracing_request_type::write_on_close && tracing_session_id_ptr) { - response.first->set_tracing_id(*tracing_session_id_ptr); - } + return process_request_stage(this, bv, op, stream, std::move(client_state), tracing_requested).then([] (auto&& response) { return std::make_pair(make_foreign(response.first), response.second); }); }).then([this, flags] (auto&& response) { @@ -764,9 +752,9 @@ future cql_server::connection::process_startup(uint16_t stream, b } auto& a = auth::authenticator::get(); if (a.require_authentication()) { - return make_ready_future(std::make_pair(make_autheticate(stream, a.class_name()), client_state)); + return make_ready_future(std::make_pair(make_autheticate(stream, a.class_name(), client_state.get_trace_state()), client_state)); } - return make_ready_future(std::make_pair(make_ready(stream), client_state)); + return make_ready_future(std::make_pair(make_ready(stream, client_state.get_trace_state()), client_state)); } future cql_server::connection::process_auth_response(uint16_t stream, bytes_view buf, service::client_state client_state) @@ -781,16 +769,18 @@ future cql_server::connection::process_auth_response(uint16_t str client_state.set_login(std::move(user)); auto f = client_state.check_user_exists(); return f.then([this, stream, client_state = std::move(client_state), challenge = std::move(challenge)]() mutable { - return make_ready_future(std::make_pair(make_auth_success(stream, std::move(challenge)), std::move(client_state))); + auto tr_state = client_state.get_trace_state(); + return make_ready_future(std::make_pair(make_auth_success(stream, std::move(challenge), tr_state), std::move(client_state))); }); }); } - return make_ready_future(std::make_pair(make_auth_challenge(stream, std::move(challenge)), std::move(client_state))); + auto tr_state = client_state.get_trace_state(); + return make_ready_future(std::make_pair(make_auth_challenge(stream, std::move(challenge), tr_state), std::move(client_state))); } future cql_server::connection::process_options(uint16_t stream, bytes_view buf, service::client_state client_state) { - return make_ready_future(std::make_pair(make_supported(stream), client_state)); + return make_ready_future(std::make_pair(make_supported(stream, client_state.get_trace_state()), client_state)); } void @@ -817,7 +807,7 @@ future cql_server::connection::process_query(uint16_t stream, byt return _server._query_processor.local().process(query, query_state, options).then([this, stream, buf = std::move(buf), &query_state, skip_metadata] (auto msg) { tracing::trace(query_state.get_trace_state(), "Done processing - preparing a result"); - return this->make_result(stream, msg, skip_metadata); + return this->make_result(stream, msg, query_state.get_trace_state(), skip_metadata); }).then([&query_state, q_state = std::move(q_state), this] (auto&& response) { /* Keep q_state alive. */ return make_ready_future(std::make_pair(response, query_state.get_client_state())); @@ -850,7 +840,7 @@ future cql_server::connection::process_prepare(uint16_t stream, b tracing::trace(cs.get_trace_state(), "Done preparing on a local shard - preparing a result. ID is [{}]", seastar::value_of([&msg] { return messages::result_message::prepared::cql::get_id(msg); })); - return this->make_result(stream, msg); + return this->make_result(stream, msg, cs.get_trace_state()); }); }).then([client_state = std::move(client_state)] (auto&& response) { /* keep client_state alive */ @@ -898,7 +888,7 @@ future cql_server::connection::process_execute(uint16_t stream, b tracing::trace(query_state.get_trace_state(), "Processing a statement"); return _server._query_processor.local().process_statement(stmt, query_state, options).then([this, stream, buf = std::move(buf), &query_state, skip_metadata] (auto msg) { tracing::trace(query_state.get_trace_state(), "Done processing - preparing a result"); - return this->make_result(stream, msg, skip_metadata); + return this->make_result(stream, msg, query_state.get_trace_state(), skip_metadata); }).then([&query_state, q_state = std::move(q_state), this] (auto&& response) { /* Keep q_state alive. */ return make_ready_future(std::make_pair(response, query_state.get_client_state())); @@ -981,8 +971,8 @@ cql_server::connection::process_batch(uint16_t stream, bytes_view buf, service:: tracing::trace(client_state.get_trace_state(), "Creating a batch statement"); auto batch = ::make_shared(cql3::statements::batch_statement::type(type), std::move(modifications), cql3::attributes::none(), _server._query_processor.local().get_cql_stats()); - return _server._query_processor.local().process_batch(batch, query_state, options).then([this, stream, batch] (auto msg) { - return this->make_result(stream, msg); + return _server._query_processor.local().process_batch(batch, query_state, options).then([this, stream, batch, &query_state] (auto msg) { + return this->make_result(stream, msg, query_state.get_trace_state()); }).then([&query_state, q_state = std::move(q_state), this] (auto&& response) { /* Keep q_state alive. */ return make_ready_future(std::make_pair(response, query_state.get_client_state())); @@ -998,12 +988,12 @@ cql_server::connection::process_register(uint16_t stream, bytes_view buf, servic auto et = parse_event_type(event_type); _server._notifier->register_event(et, this); } - return make_ready_future(std::make_pair(make_ready(stream), client_state)); + return make_ready_future(std::make_pair(make_ready(stream, client_state.get_trace_state()), client_state)); } -shared_ptr cql_server::connection::make_unavailable_error(int16_t stream, exceptions::exception_code err, sstring msg, db::consistency_level cl, int32_t required, int32_t alive) +shared_ptr cql_server::connection::make_unavailable_error(int16_t stream, exceptions::exception_code err, sstring msg, db::consistency_level cl, int32_t required, int32_t alive, const tracing::trace_state_ptr& tr_state) { - auto response = make_shared(stream, cql_binary_opcode::ERROR); + auto response = make_shared(stream, cql_binary_opcode::ERROR, tr_state); response->write_int(static_cast(err)); response->write_string(msg); response->write_consistency(cl); @@ -1012,9 +1002,9 @@ shared_ptr cql_server::connection::make_unavailable_error( return response; } -shared_ptr cql_server::connection::make_read_timeout_error(int16_t stream, exceptions::exception_code err, sstring msg, db::consistency_level cl, int32_t received, int32_t blockfor, bool data_present) +shared_ptr cql_server::connection::make_read_timeout_error(int16_t stream, exceptions::exception_code err, sstring msg, db::consistency_level cl, int32_t received, int32_t blockfor, bool data_present, const tracing::trace_state_ptr& tr_state) { - auto response = make_shared(stream, cql_binary_opcode::ERROR); + auto response = make_shared(stream, cql_binary_opcode::ERROR, tr_state); response->write_int(static_cast(err)); response->write_string(msg); response->write_consistency(cl); @@ -1024,9 +1014,9 @@ shared_ptr cql_server::connection::make_read_timeout_error return response; } -shared_ptr cql_server::connection::make_mutation_write_timeout_error(int16_t stream, exceptions::exception_code err, sstring msg, db::consistency_level cl, int32_t received, int32_t blockfor, db::write_type type) +shared_ptr cql_server::connection::make_mutation_write_timeout_error(int16_t stream, exceptions::exception_code err, sstring msg, db::consistency_level cl, int32_t received, int32_t blockfor, db::write_type type, const tracing::trace_state_ptr& tr_state) { - auto response = make_shared(stream, cql_binary_opcode::ERROR); + auto response = make_shared(stream, cql_binary_opcode::ERROR, tr_state); response->write_int(static_cast(err)); response->write_string(msg); response->write_consistency(cl); @@ -1036,9 +1026,9 @@ shared_ptr cql_server::connection::make_mutation_write_tim return response; } -shared_ptr cql_server::connection::make_already_exists_error(int16_t stream, exceptions::exception_code err, sstring msg, sstring ks_name, sstring cf_name) +shared_ptr cql_server::connection::make_already_exists_error(int16_t stream, exceptions::exception_code err, sstring msg, sstring ks_name, sstring cf_name, const tracing::trace_state_ptr& tr_state) { - auto response = make_shared(stream, cql_binary_opcode::ERROR); + auto response = make_shared(stream, cql_binary_opcode::ERROR, tr_state); response->write_int(static_cast(err)); response->write_string(msg); response->write_string(ks_name); @@ -1046,54 +1036,54 @@ shared_ptr cql_server::connection::make_already_exists_err return response; } -shared_ptr cql_server::connection::make_unprepared_error(int16_t stream, exceptions::exception_code err, sstring msg, bytes id) +shared_ptr cql_server::connection::make_unprepared_error(int16_t stream, exceptions::exception_code err, sstring msg, bytes id, const tracing::trace_state_ptr& tr_state) { - auto response = make_shared(stream, cql_binary_opcode::ERROR); + auto response = make_shared(stream, cql_binary_opcode::ERROR, tr_state); response->write_int(static_cast(err)); response->write_string(msg); response->write_short_bytes(id); return response; } -shared_ptr cql_server::connection::make_error(int16_t stream, exceptions::exception_code err, sstring msg) +shared_ptr cql_server::connection::make_error(int16_t stream, exceptions::exception_code err, sstring msg, const tracing::trace_state_ptr& tr_state) { - auto response = make_shared(stream, cql_binary_opcode::ERROR); + auto response = make_shared(stream, cql_binary_opcode::ERROR, tr_state); response->write_int(static_cast(err)); response->write_string(msg); return response; } -shared_ptr cql_server::connection::make_ready(int16_t stream) +shared_ptr cql_server::connection::make_ready(int16_t stream, const tracing::trace_state_ptr& tr_state) { - return make_shared(stream, cql_binary_opcode::READY); + return make_shared(stream, cql_binary_opcode::READY, tr_state); } -shared_ptr cql_server::connection::make_autheticate(int16_t stream, const sstring& clz) +shared_ptr cql_server::connection::make_autheticate(int16_t stream, const sstring& clz, const tracing::trace_state_ptr& tr_state) { - auto response = make_shared(stream, cql_binary_opcode::AUTHENTICATE); + auto response = make_shared(stream, cql_binary_opcode::AUTHENTICATE, tr_state); response->write_string(clz); return response; } -shared_ptr cql_server::connection::make_auth_success(int16_t stream, bytes b) { - auto response = make_shared(stream, cql_binary_opcode::AUTH_SUCCESS); +shared_ptr cql_server::connection::make_auth_success(int16_t stream, bytes b, const tracing::trace_state_ptr& tr_state) { + auto response = make_shared(stream, cql_binary_opcode::AUTH_SUCCESS, tr_state); response->write_bytes(std::move(b)); return response; } -shared_ptr cql_server::connection::make_auth_challenge(int16_t stream, bytes b) { - auto response = make_shared(stream, cql_binary_opcode::AUTH_CHALLENGE); +shared_ptr cql_server::connection::make_auth_challenge(int16_t stream, bytes b, const tracing::trace_state_ptr& tr_state) { + auto response = make_shared(stream, cql_binary_opcode::AUTH_CHALLENGE, tr_state); response->write_bytes(std::move(b)); return response; } -shared_ptr cql_server::connection::make_supported(int16_t stream) +shared_ptr cql_server::connection::make_supported(int16_t stream, const tracing::trace_state_ptr& tr_state) { std::multimap opts; opts.insert({"CQL_VERSION", cql3::query_processor::CQL_VERSION}); opts.insert({"COMPRESSION", "lz4"}); opts.insert({"COMPRESSION", "snappy"}); - auto response = make_shared(stream, cql_binary_opcode::SUPPORTED); + auto response = make_shared(stream, cql_binary_opcode::SUPPORTED, tr_state); response->write_string_multimap(opts); return response; } @@ -1156,9 +1146,9 @@ public: }; shared_ptr -cql_server::connection::make_result(int16_t stream, shared_ptr msg, bool skip_metadata) +cql_server::connection::make_result(int16_t stream, shared_ptr msg, const tracing::trace_state_ptr& tr_state, bool skip_metadata) { - auto response = make_shared(stream, cql_binary_opcode::RESULT); + auto response = make_shared(stream, cql_binary_opcode::RESULT, tr_state); fmt_visitor fmt{_version, response, skip_metadata}; msg->accept(fmt); return response; @@ -1167,7 +1157,7 @@ cql_server::connection::make_result(int16_t stream, shared_ptr cql_server::connection::make_topology_change_event(const event::topology_change& event) { - auto response = make_shared(-1, cql_binary_opcode::EVENT); + auto response = make_shared(-1, cql_binary_opcode::EVENT, tracing::trace_state_ptr()); response->write_string("TOPOLOGY_CHANGE"); response->write_string(to_string(event.change)); response->write_inet(event.node); @@ -1177,7 +1167,7 @@ cql_server::connection::make_topology_change_event(const event::topology_change& shared_ptr cql_server::connection::make_status_change_event(const event::status_change& event) { - auto response = make_shared(-1, cql_binary_opcode::EVENT); + auto response = make_shared(-1, cql_binary_opcode::EVENT, tracing::trace_state_ptr()); response->write_string("STATUS_CHANGE"); response->write_string(to_string(event.status)); response->write_inet(event.node); @@ -1187,7 +1177,7 @@ cql_server::connection::make_status_change_event(const event::status_change& eve shared_ptr cql_server::connection::make_schema_change_event(const event::schema_change& event) { - auto response = make_shared(-1, cql_binary_opcode::EVENT); + auto response = make_shared(-1, cql_binary_opcode::EVENT, tracing::trace_state_ptr()); response->write_string("SCHEMA_CHANGE"); response->serialize(event, _version); return response; @@ -1496,7 +1486,7 @@ cql3::raw_value_view cql_server::connection::read_value_view(bytes_view& buf) { scattered_message cql_server::response::make_message(uint8_t version) { scattered_message msg; sstring body{_body.data(), _body.size()}; - sstring frame = make_frame(version, 0x00, body.size()); + sstring frame = make_frame(version, _body.size()); msg.append(std::move(frame)); msg.append(std::move(body)); return msg; @@ -1504,12 +1494,10 @@ scattered_message cql_server::response::make_message(uint8_t version) { future<> cql_server::response::output(output_stream& out, uint8_t version, cql_compression compression) { - uint8_t flags = 0; if (compression != cql_compression::none) { - flags |= cql_frame_flags::compression; - _body = compress(_body, compression); + compress(compression); } - auto frame = make_frame(version, flags, _body.size()); + auto frame = make_frame(version, _body.size()); auto tmp = temporary_buffer(frame.size()); std::copy_n(frame.begin(), frame.size(), tmp.get_write()); auto f = out.write(tmp.get(), tmp.size()); @@ -1518,13 +1506,19 @@ cql_server::response::output(output_stream& out, uint8_t version, cql_comp }); } -std::vector cql_server::response::compress(const std::vector& body, cql_compression compression) +void cql_server::response::compress(cql_compression compression) { switch (compression) { - case cql_compression::lz4: return compress_lz4(body); - case cql_compression::snappy: return compress_snappy(body); - default: throw std::invalid_argument("Invalid CQL compression algorithm"); + case cql_compression::lz4: + _body = compress_lz4(_body); + break; + case cql_compression::snappy: + _body = compress_snappy(_body); + break; + default: + throw std::invalid_argument("Invalid CQL compression algorithm"); } + set_frame_flag(cql_frame_flags::compression); } std::vector cql_server::response::compress_lz4(const std::vector& body) diff --git a/transport/server.hh b/transport/server.hh index 603cd927fb..b47ceeaf0a 100644 --- a/transport/server.hh +++ b/transport/server.hh @@ -177,21 +177,21 @@ private: future process_batch(uint16_t stream, bytes_view buf, service::client_state client_state); future process_register(uint16_t stream, bytes_view buf, service::client_state client_state); - shared_ptr make_unavailable_error(int16_t stream, exceptions::exception_code err, sstring msg, db::consistency_level cl, int32_t required, int32_t alive); - shared_ptr make_read_timeout_error(int16_t stream, exceptions::exception_code err, sstring msg, db::consistency_level cl, int32_t received, int32_t blockfor, bool data_present); - shared_ptr make_mutation_write_timeout_error(int16_t stream, exceptions::exception_code err, sstring msg, db::consistency_level cl, int32_t received, int32_t blockfor, db::write_type type); - shared_ptr make_already_exists_error(int16_t stream, exceptions::exception_code err, sstring msg, sstring ks_name, sstring cf_name); - shared_ptr make_unprepared_error(int16_t stream, exceptions::exception_code err, sstring msg, bytes id); - shared_ptr make_error(int16_t stream, exceptions::exception_code err, sstring msg); - shared_ptr make_ready(int16_t stream); - shared_ptr make_supported(int16_t stream); - shared_ptr make_result(int16_t stream, shared_ptr msg, bool skip_metadata = false); + shared_ptr make_unavailable_error(int16_t stream, exceptions::exception_code err, sstring msg, db::consistency_level cl, int32_t required, int32_t alive, const tracing::trace_state_ptr& tr_state); + shared_ptr make_read_timeout_error(int16_t stream, exceptions::exception_code err, sstring msg, db::consistency_level cl, int32_t received, int32_t blockfor, bool data_present, const tracing::trace_state_ptr& tr_state); + shared_ptr make_mutation_write_timeout_error(int16_t stream, exceptions::exception_code err, sstring msg, db::consistency_level cl, int32_t received, int32_t blockfor, db::write_type type, const tracing::trace_state_ptr& tr_state); + shared_ptr make_already_exists_error(int16_t stream, exceptions::exception_code err, sstring msg, sstring ks_name, sstring cf_name, const tracing::trace_state_ptr& tr_state); + shared_ptr make_unprepared_error(int16_t stream, exceptions::exception_code err, sstring msg, bytes id, const tracing::trace_state_ptr& tr_state); + shared_ptr make_error(int16_t stream, exceptions::exception_code err, sstring msg, const tracing::trace_state_ptr& tr_state); + shared_ptr make_ready(int16_t stream, const tracing::trace_state_ptr& tr_state); + shared_ptr make_supported(int16_t stream, const tracing::trace_state_ptr& tr_state); + shared_ptr make_result(int16_t stream, shared_ptr msg, const tracing::trace_state_ptr& tr_state, bool skip_metadata = false); shared_ptr make_topology_change_event(const transport::event::topology_change& event); shared_ptr make_status_change_event(const transport::event::status_change& event); shared_ptr make_schema_change_event(const transport::event::schema_change& event); - shared_ptr make_autheticate(int16_t, const sstring&); - shared_ptr make_auth_success(int16_t, bytes); - shared_ptr make_auth_challenge(int16_t, bytes); + shared_ptr make_autheticate(int16_t, const sstring&, const tracing::trace_state_ptr& tr_state); + shared_ptr make_auth_success(int16_t, bytes, const tracing::trace_state_ptr& tr_state); + shared_ptr make_auth_challenge(int16_t, bytes, const tracing::trace_state_ptr& tr_state); future<> write_response(foreign_ptr>&& response, cql_compression compression = cql_compression::none); diff --git a/types.cc b/types.cc index 81a7199b3f..a51326ab49 100644 --- a/types.cc +++ b/types.cc @@ -520,7 +520,7 @@ struct timeuuid_type_impl : public concrete_type { return; } auto uuid = uuid1.get(); - out = std::copy_n(uuid.to_bytes().begin(), sizeof(uuid), out); + uuid.serialize(out); } virtual size_t serialized_size(const void* value) const override { if (!value || from_value(value).empty()) { @@ -583,7 +583,7 @@ struct timeuuid_type_impl : public concrete_type { if (v.version() != 1) { throw marshal_exception(); } - return v.to_bytes(); + return v.serialize(); } virtual sstring to_string(const bytes& b) const override { auto v = deserialize(b); @@ -972,8 +972,7 @@ struct uuid_type_impl : concrete_type { if (!value) { return; } - auto& uuid = from_value(value); - out = std::copy_n(uuid.get().to_bytes().begin(), sizeof(uuid.get()), out); + from_value(value).get().serialize(out); } virtual size_t serialized_size(const void* value) const override { if (!value) { @@ -1035,7 +1034,7 @@ struct uuid_type_impl : concrete_type { throw marshal_exception(); } utils::UUID v(s); - return v.to_bytes(); + return v.serialize(); } virtual sstring to_string(const bytes& b) const override { auto v = deserialize(b); diff --git a/utils/UUID.hh b/utils/UUID.hh index 96ffed59df..f12bc86ecc 100644 --- a/utils/UUID.hh +++ b/utils/UUID.hh @@ -113,13 +113,22 @@ public: return !(*this < v); } - bytes to_bytes() const { - bytes b(bytes::initialized_later(),16); + bytes serialize() const { + bytes b(bytes::initialized_later(), serialized_size()); auto i = b.begin(); - serialize_int64(i, most_sig_bits); - serialize_int64(i, least_sig_bits); + serialize(i); return b; } + + static size_t serialized_size() noexcept { + return 16; + } + + template + void serialize(CharOutputIterator& out) const { + serialize_int64(out, most_sig_bits); + serialize_int64(out, least_sig_bits); + } }; UUID make_random_uuid(); diff --git a/utils/serialization.hh b/utils/serialization.hh index 752ab07fe2..382492ed10 100644 --- a/utils/serialization.hh +++ b/utils/serialization.hh @@ -40,36 +40,65 @@ #include +#include #include "core/sstring.hh" #include "net/byteorder.hh" #include "bytes.hh" #include +#include class UTFDataFormatException { }; class EOFException { }; -inline -void serialize_int8(std::ostream& out, uint8_t val) { - out.put(val); -} +static constexpr size_t serialize_int8_size = 1; +static constexpr size_t serialize_bool_size = 1; +static constexpr size_t serialize_int16_size = 2; +static constexpr size_t serialize_int32_size = 4; +static constexpr size_t serialize_int64_size = 8; -inline -void serialize_int8(std::ostream& out, int8_t val) { - out.put(val); -} +namespace internal { +template +GCC6_CONCEPT(requires std::is_integral::value && std::is_integral::value && requires (CharOutputIterator it) { + *it++ = 'a'; +}) inline -void serialize_int8(bytes::iterator& out, uint8_t val) { - uint8_t nval = net::hton(val); +void serialize_int(CharOutputIterator& out, IntegerType val) { + ExplicitIntegerType nval = net::hton(ExplicitIntegerType(val)); out = std::copy_n(reinterpret_cast(&nval), sizeof(nval), out); } -static constexpr size_t serialize_int8_size = 1; +} +template inline -void serialize_int8(std::ostream& out, char val) { - out.put(val); +void serialize_int8(CharOutputIterator& out, uint8_t val) { + internal::serialize_int(out, val); +} + +template +inline +void serialize_int16(CharOutputIterator& out, uint16_t val) { + internal::serialize_int(out, val); +} + +template +inline +void serialize_int32(CharOutputIterator& out, uint32_t val) { + internal::serialize_int(out, val); +} + +template +inline +void serialize_int64(CharOutputIterator& out, uint64_t val) { + internal::serialize_int(out, val); +} + +template +inline +void serialize_bool(CharOutputIterator& out, bool val) { + serialize_int8(out, val ? 1 : 0); } inline @@ -82,44 +111,9 @@ int8_t deserialize_int8(std::istream& in) { } } -inline -void serialize_bool(std::ostream& out, bool b) { - out.put(b ? (char)1 : (char)0); -} - -static constexpr size_t serialize_bool_size = 1; - -inline -void serialize_bool(bytes::iterator& out, bool val) { - serialize_int8(out, val ? 1 : 0); -} - inline bool deserialize_bool(std::istream& in) { - char ret; - if (in.get(ret)) { - return ret; - } else { - throw EOFException(); - } -} - - -inline -void serialize_int16(std::ostream& out, uint16_t val) { - out.put((char)((val >> 8) & 0xFF)); - out.put((char)((val >> 0) & 0xFF)); -} - -inline -void serialize_int16(std::ostream& out, int16_t val) { - serialize_int16(out, (uint16_t) val); -} - -inline -void serialize_int16(bytes::iterator& out, uint16_t val) { - uint16_t nval = net::hton(val); - out = std::copy_n(reinterpret_cast(&nval), sizeof(nval), out); + return deserialize_int8(in); } inline @@ -133,29 +127,6 @@ int16_t deserialize_int16(std::istream& in) { return ((int16_t)(uint8_t)a1 << 8) | ((int16_t)(uint8_t)a2 << 0); } -static constexpr size_t serialize_int16_size = 2; - -inline -void serialize_int32(std::ostream& out, uint32_t val) { - out.put((char)((val >> 24) & 0xFF)); - out.put((char)((val >> 16) & 0xFF)); - out.put((char)((val >> 8) & 0xFF)); - out.put((char)((val >> 0) & 0xFF)); -} - -inline -void serialize_int32(std::ostream& out, int32_t val) { - serialize_int32(out, (uint32_t) val); -} - -inline -void serialize_int32(bytes::iterator& out, uint32_t val) { - uint32_t nval = net::hton(val); - out = std::copy_n(reinterpret_cast(&nval), sizeof(nval), out); -} - -static constexpr size_t serialize_int32_size = 4; - inline int32_t deserialize_int32(std::istream& in) { char a1, a2, a3, a4; @@ -169,31 +140,6 @@ int32_t deserialize_int32(std::istream& in) { ((int32_t)(uint8_t)a4 << 0); } -inline -void serialize_int64(std::ostream& out, uint64_t val) { - out.put((char)((val >> 56) & 0xFF)); - out.put((char)((val >> 48) & 0xFF)); - out.put((char)((val >> 40) & 0xFF)); - out.put((char)((val >> 32) & 0xFF)); - out.put((char)((val >> 24) & 0xFF)); - out.put((char)((val >> 16) & 0xFF)); - out.put((char)((val >> 8) & 0xFF)); - out.put((char)((val >> 0) & 0xFF)); -} - -inline -void serialize_int64(std::ostream& out, int64_t val) { - serialize_int64(out, (uint64_t) val); -} - -inline -void serialize_int64(bytes::iterator& out, uint64_t val) { - uint64_t nval = net::hton(val); - out = std::copy_n(reinterpret_cast(&nval), sizeof(nval), out); -} - -static constexpr size_t serialize_int64_size = 8; - inline int64_t deserialize_int64(std::istream& in) { char a1, a2, a3, a4, a5, a6, a7, a8; @@ -222,8 +168,12 @@ int64_t deserialize_int64(std::istream& in) { // http://docs.oracle.com/javase/7/docs/api/java/io/DataInput.html#modified-utf-8) // For now we'll just assume those aren't in the string... // TODO: fix the compatibility with Java even in this case. +template +GCC6_CONCEPT(requires requires (CharOutputIterator it) { + *it++ = 'a'; +}) inline -void serialize_string(std::ostream& out, const sstring& s) { +void serialize_string(CharOutputIterator& out, const sstring& s) { // Java specifies that nulls in the string need to be replaced by the // two bytes 0xC0, 0x80. Let's not bother with such transformation // now, but just verify wasn't needed. @@ -237,33 +187,16 @@ void serialize_string(std::ostream& out, const sstring& s) { // can't serialize longer strings. throw UTFDataFormatException(); } - serialize_int16(out, (uint16_t) s.size()); - out.write(s.c_str(), s.size()); -} - -inline -void serialize_string(bytes::iterator& out, const sstring& s) { - for (char c : s) { - if (c == '\0') { - throw UTFDataFormatException(); - } - } - if (s.size() > std::numeric_limits::max()) { - throw UTFDataFormatException(); - } - serialize_int16(out, (uint16_t) s.size()); + serialize_int16(out, s.size()); out = std::copy(s.begin(), s.end(), out); } +template +GCC6_CONCEPT(requires requires (CharOutputIterator it) { + *it++ = 'a'; +}) inline -size_t serialize_string_size(const sstring& s) {; - // As above, this code is missing the case of modified utf-8 - return serialize_int16_size + s.size(); -} - - -inline -void serialize_string(std::ostream& out, const char *s) { +void serialize_string(CharOutputIterator& out, const char* s) { // TODO: like above, need to change UTF-8 when above 16-bit. auto len = strlen(s); if (len > std::numeric_limits::max()) { @@ -271,8 +204,14 @@ void serialize_string(std::ostream& out, const char *s) { // can't serialize longer strings. throw UTFDataFormatException(); } - serialize_int16(out, (uint16_t) len); - out.write(s, len); + serialize_int16(out, len); + out = std::copy_n(s, len, out); +} + +inline +size_t serialize_string_size(const sstring& s) {; + // As above, this code is missing the case of modified utf-8 + return serialize_int16_size + s.size(); } inline