diff --git a/transport/server.cc b/transport/server.cc index 448b128123..13f16c20df 100644 --- a/transport/server.cc +++ b/transport/server.cc @@ -37,12 +37,6 @@ struct cql_frame_error : std::exception { } }; -struct bad_cql_protocol_version : std::exception { - const char* what() const throw () override { - return "bad cql binary protocol version"; - } -}; - struct [[gnu::packed]] cql_binary_frame_v1 { uint8_t version; uint8_t flags; @@ -207,8 +201,22 @@ public: _server._notifier->unregister_connection(this); } future<> process() { - return do_until([this] { return _read_buf.eof(); }, [this] { return process_request(); }) - .finally([this] { + return do_until([this] { + return _read_buf.eof(); + }, [this] { + return process_request(); + }).then_wrapped([this] (future<> f) { + try { + f.get(); + return make_ready_future<>(); + } catch (const exceptions::cassandra_exception& ex) { + return write_error(0, ex.code(), ex.what()); + } catch (std::exception& ex) { + return write_error(0, exceptions::exception_code::SERVER_ERROR, ex.what()); + } catch (...) { + return write_error(0, exceptions::exception_code::SERVER_ERROR, "unknown error"); + } + }).finally([this] { return _pending_requests_gate.close().then([this] { return std::move(_ready_to_respond); }); @@ -625,10 +633,11 @@ cql_server::connection::parse_frame(temporary_buffer buf) { break; } default: - abort(); + throw exceptions::protocol_exception(sprint("Invalid or unsupported protocol version: %d", _version)); } if (v3.version != _version) { - throw bad_cql_protocol_version(); + throw exceptions::protocol_exception(sprint("Invalid message version. Got %d but previous messages on this connection had version %d", v3.version, _version)); + } return v3; } @@ -645,8 +654,8 @@ cql_server::connection::read_frame() { } _version = buf[0]; init_serialization_format(); - if (_version < 1 || _version > 4) { - throw bad_cql_protocol_version(); + if (_version < 1 || _version > current_version) { + throw exceptions::protocol_exception(sprint("Invalid or unsupported protocol version: %d", _version)); } return _read_buf.read_exactly(frame_size() - 1).then([this] (temporary_buffer tail) { temporary_buffer full(frame_size()); diff --git a/transport/server.hh b/transport/server.hh index 19669780b3..7e1c283afb 100644 --- a/transport/server.hh +++ b/transport/server.hh @@ -24,6 +24,8 @@ class database; class cql_server { class event_notifier; + static constexpr int current_version = 3; + std::vector _listeners; distributed& _proxy; distributed& _query_processor;