diff --git a/exceptions/exceptions.hh b/exceptions/exceptions.hh index 0e11b033aa..28e318ac01 100644 --- a/exceptions/exceptions.hh +++ b/exceptions/exceptions.hh @@ -55,13 +55,7 @@ enum class exception_code : int32_t { UNPREPARED = 0x2500 }; -class transport_exception { -public: - virtual exception_code code() const = 0; - virtual sstring get_message() const = 0; -}; - -class cassandra_exception : public std::exception, public transport_exception { +class cassandra_exception : public std::exception { private: exception_code _code; sstring _msg; @@ -71,8 +65,8 @@ public: , _msg(std::move(msg)) { } virtual const char* what() const noexcept override { return _msg.begin(); } - virtual exception_code code() const override { return _code; } - virtual sstring get_message() const override { return what(); } + exception_code code() const { return _code; } + sstring get_message() const { return what(); } }; class request_validation_exception : public cassandra_exception { diff --git a/transport/protocol_exception.hh b/transport/protocol_exception.hh index c20f5de52d..95b7933bf3 100644 --- a/transport/protocol_exception.hh +++ b/transport/protocol_exception.hh @@ -28,18 +28,11 @@ namespace transport { -class protocol_exception : public std::exception, public exceptions::transport_exception { -private: - exceptions::exception_code _code; - sstring _msg; +class protocol_exception : public exceptions::cassandra_exception { public: protocol_exception(sstring msg) - : _code(exceptions::exception_code::PROTOCOL_ERROR) - , _msg(std::move(msg)) + : exceptions::cassandra_exception{exceptions::exception_code::PROTOCOL_ERROR, std::move(msg)} { } - virtual const char* what() const noexcept override { return _msg.begin(); } - virtual exceptions::exception_code code() const override { return _code; } - virtual sstring get_message() const override { return _msg; } }; } diff --git a/transport/server.cc b/transport/server.cc index b57afda34b..a24a266c94 100644 --- a/transport/server.cc +++ b/transport/server.cc @@ -8,6 +8,7 @@ #include #include #include +#include #include "db/consistency_level.hh" #include "core/future-util.hh" @@ -175,6 +176,14 @@ private: } } + void validate_utf8(sstring_view s) { + try { + boost::locale::conv::utf_to_utf(s.data(), boost::locale::conv::stop); + } catch (const boost::locale::conv::conversion_error& ex) { + throw transport::protocol_exception("Cannot decode string as UTF8"); + } + } + int8_t read_byte(temporary_buffer& buf); int32_t read_int(temporary_buffer& buf); int64_t read_long(temporary_buffer& buf); @@ -677,6 +686,7 @@ sstring cql_server::connection::read_string(temporary_buffer& buf) sstring s{buf.begin(), static_cast(n)}; assert(n >= 0); buf.trim_front(n); + validate_utf8(s); return s; } @@ -686,6 +696,7 @@ sstring_view cql_server::connection::read_long_string_view(temporary_buffer(n)}; buf.trim_front(n); + validate_utf8(s); return s; }