Merge "Fix limits handling in CQL server" from Tomasz

"Fixes the following issues:
 #807 Wrong maximum key length
 #809 Scylla assert on returning result when max column size is overflow"
This commit is contained in:
Avi Kivity
2016-02-17 15:06:51 +02:00
3 changed files with 30 additions and 26 deletions

View File

@@ -43,6 +43,7 @@ public:
static constexpr bool is_prefixable = AllowPrefixes == allow_prefixes::yes;
using prefix_type = compound_type<allow_prefixes::yes>;
using value_type = std::vector<bytes>;
using size_type = uint16_t;
compound_type(std::vector<data_type> types)
: _types(std::move(types))
@@ -75,8 +76,8 @@ private:
template<typename RangeOfSerializedComponents>
static void serialize_value(RangeOfSerializedComponents&& values, bytes::iterator& out) {
for (auto&& val : values) {
assert(val.size() <= std::numeric_limits<uint16_t>::max());
write<uint16_t>(out, uint16_t(val.size()));
assert(val.size() <= std::numeric_limits<size_type>::max());
write<size_type>(out, size_type(val.size()));
out = std::copy(val.begin(), val.end(), out);
}
}
@@ -84,7 +85,7 @@ private:
static size_t serialized_size(RangeOfSerializedComponents&& values) {
size_t len = 0;
for (auto&& val : values) {
len += sizeof(uint16_t) + val.size();
len += sizeof(size_type) + val.size();
}
return len;
}
@@ -95,8 +96,8 @@ public:
template<typename RangeOfSerializedComponents>
static bytes serialize_value(RangeOfSerializedComponents&& values) {
auto size = serialized_size(values);
if (size > std::numeric_limits<uint16_t>::max()) {
throw std::runtime_error(sprint("Key size too large: %d > %d", size, std::numeric_limits<uint16_t>::max()));
if (size > std::numeric_limits<size_type>::max()) {
throw std::runtime_error(sprint("Key size too large: %d > %d", size, std::numeric_limits<size_type>::max()));
}
bytes b(bytes::initialized_later(), size);
auto i = b.begin();
@@ -135,13 +136,13 @@ public:
value_type _current;
private:
void read_current() {
uint16_t len;
size_type len;
{
if (_v.empty()) {
_v = bytes_view(nullptr, 0);
return;
}
len = read_simple<uint16_t>(_v);
len = read_simple<size_type>(_v);
if (_v.size() < len) {
throw marshal_exception();
}

View File

@@ -191,7 +191,7 @@ public:
void write_byte(uint8_t b);
void write_int(int32_t n);
void write_long(int64_t n);
void write_short(int16_t n);
void write_short(uint16_t n);
void write_string(const sstring& s);
void write_long_string(const sstring& s);
void write_uuid(utils::UUID uuid);
@@ -1048,9 +1048,9 @@ int64_t cql_server::connection::read_long(bytes_view& buf)
return n;
}
int16_t cql_server::connection::read_short(bytes_view& buf)
uint16_t cql_server::connection::read_short(bytes_view& buf)
{
return static_cast<int16_t>(read_unsigned_short(buf));
return read_unsigned_short(buf);
}
uint16_t cql_server::connection::read_unsigned_short(bytes_view& buf)
@@ -1357,24 +1357,32 @@ void cql_server::response::write_long(int64_t n)
_body.insert(_body.end(), s, s+sizeof(u));
}
void cql_server::response::write_short(int16_t n)
void cql_server::response::write_short(uint16_t n)
{
auto u = htons(n);
auto *s = reinterpret_cast<const char*>(&u);
_body.insert(_body.end(), s, s+sizeof(u));
}
template<typename T>
inline
T cast_if_fits(size_t v) {
size_t max = std::numeric_limits<T>::max();
if (v > max) {
throw std::runtime_error(sprint("Value to large, %d > %d", v, max));
}
return static_cast<T>(v);
}
void cql_server::response::write_string(const sstring& s)
{
assert(s.size() < std::numeric_limits<int16_t>::max());
write_short(s.size());
write_short(cast_if_fits<uint16_t>(s.size()));
_body.insert(_body.end(), s.begin(), s.end());
}
void cql_server::response::write_long_string(const sstring& s)
{
assert(s.size() < std::numeric_limits<int32_t>::max());
write_int(s.size());
write_int(cast_if_fits<int32_t>(s.size()));
_body.insert(_body.end(), s.begin(), s.end());
}
@@ -1386,8 +1394,7 @@ void cql_server::response::write_uuid(utils::UUID uuid)
void cql_server::response::write_string_list(std::vector<sstring> string_list)
{
assert(string_list.size() < std::numeric_limits<int16_t>::max());
write_short(string_list.size());
write_short(cast_if_fits<uint16_t>(string_list.size()));
for (auto&& s : string_list) {
write_string(s);
}
@@ -1395,15 +1402,13 @@ void cql_server::response::write_string_list(std::vector<sstring> string_list)
void cql_server::response::write_bytes(bytes b)
{
assert(b.size() < std::numeric_limits<int32_t>::max());
write_int(b.size());
write_int(cast_if_fits<int32_t>(b.size()));
_body.insert(_body.end(), b.begin(), b.end());
}
void cql_server::response::write_short_bytes(bytes b)
{
assert(b.size() < std::numeric_limits<int16_t>::max());
write_short(b.size());
write_short(cast_if_fits<uint16_t>(b.size()));
_body.insert(_body.end(), b.begin(), b.end());
}
@@ -1436,8 +1441,7 @@ void cql_server::response::write_consistency(db::consistency_level c)
void cql_server::response::write_string_map(std::map<sstring, sstring> string_map)
{
assert(string_map.size() < std::numeric_limits<int16_t>::max());
write_short(string_map.size());
write_short(cast_if_fits<uint16_t>(string_map.size()));
for (auto&& s : string_map) {
write_string(s.first);
write_string(s.second);
@@ -1450,8 +1454,7 @@ void cql_server::response::write_string_multimap(std::multimap<sstring, sstring>
for (auto it = string_map.begin(), end = string_map.end(); it != end; it = string_map.upper_bound(it->first)) {
keys.push_back(it->first);
}
assert(keys.size() < std::numeric_limits<int16_t>::max());
write_short(keys.size());
write_short(cast_if_fits<uint16_t>(keys.size()));
for (auto&& key : keys) {
std::vector<sstring> values;
auto range = string_map.equal_range(key);

View File

@@ -178,7 +178,7 @@ private:
int8_t read_byte(bytes_view& buf);
int32_t read_int(bytes_view& buf);
int64_t read_long(bytes_view& buf);
int16_t read_short(bytes_view& buf);
uint16_t read_short(bytes_view& buf);
uint16_t read_unsigned_short(bytes_view& buf);
sstring read_string(bytes_view& buf);
sstring_view read_string_view(bytes_view& buf);