transport: server: Guard against buffer overrun when parsing CQL3 frame

This commit is contained in:
Tomasz Grabiec
2015-02-13 15:26:20 +01:00
parent a51460508d
commit a37baeb81b

View File

@@ -12,6 +12,7 @@
#include "net/byteorder.hh"
#include "cql3/CqlParser.hpp"
#include "transport/protocol_exception.hh"
#include <cassert>
#include <string>
@@ -163,6 +164,12 @@ private:
future<> write_supported(int16_t stream);
future<> write_response(shared_ptr<cql_server::response> response);
void check_room(temporary_buffer<char>& buf, size_t n) {
if (buf.size() < n) {
throw transport::protocol_exception("truncated frame");
}
}
int8_t read_byte(temporary_buffer<char>& buf);
int32_t read_int(temporary_buffer<char>& buf);
int64_t read_long(temporary_buffer<char>& buf);
@@ -430,6 +437,7 @@ future<> cql_server::connection::write_response(shared_ptr<cql_server::response>
int8_t cql_server::connection::read_byte(temporary_buffer<char>& buf)
{
check_room(buf, 1);
int8_t n = buf[0];
buf.trim_front(1);
return n;
@@ -437,6 +445,7 @@ int8_t cql_server::connection::read_byte(temporary_buffer<char>& buf)
int32_t cql_server::connection::read_int(temporary_buffer<char>& buf)
{
check_room(buf, sizeof(int32_t));
auto p = reinterpret_cast<const uint8_t*>(buf.begin());
uint32_t n = (static_cast<uint32_t>(p[0]) << 24)
| (static_cast<uint32_t>(p[1]) << 16)
@@ -448,6 +457,7 @@ int32_t cql_server::connection::read_int(temporary_buffer<char>& buf)
int64_t cql_server::connection::read_long(temporary_buffer<char>& buf)
{
check_room(buf, sizeof(int64_t));
auto p = reinterpret_cast<const uint8_t*>(buf.begin());
uint64_t n = (static_cast<uint64_t>(p[0]) << 56)
| (static_cast<uint64_t>(p[1]) << 48)
@@ -463,6 +473,7 @@ int64_t cql_server::connection::read_long(temporary_buffer<char>& buf)
int16_t cql_server::connection::read_short(temporary_buffer<char>& buf)
{
check_room(buf, sizeof(uint16_t));
auto p = reinterpret_cast<const uint8_t*>(buf.begin());
uint16_t n = (static_cast<uint16_t>(p[0]) << 8)
| (static_cast<uint16_t>(p[1]));
@@ -473,6 +484,7 @@ int16_t cql_server::connection::read_short(temporary_buffer<char>& buf)
sstring cql_server::connection::read_string(temporary_buffer<char>& buf)
{
auto n = read_short(buf);
check_room(buf, n);
sstring s{buf.begin(), static_cast<size_t>(n)};
assert(n >= 0);
buf.trim_front(n);
@@ -482,6 +494,7 @@ sstring cql_server::connection::read_string(temporary_buffer<char>& buf)
sstring cql_server::connection::read_long_string(temporary_buffer<char>& buf)
{
auto n = read_int(buf);
check_room(buf, n);
sstring s{buf.begin(), static_cast<size_t>(n)};
buf.trim_front(n);
return s;