diff --git a/compound.hh b/compound.hh index da3b7eeebf..e702a269de 100644 --- a/compound.hh +++ b/compound.hh @@ -43,6 +43,7 @@ public: static constexpr bool is_prefixable = AllowPrefixes == allow_prefixes::yes; using prefix_type = compound_type; using value_type = std::vector; + using size_type = uint16_t; compound_type(std::vector types) : _types(std::move(types)) @@ -75,8 +76,8 @@ private: template static void serialize_value(RangeOfSerializedComponents&& values, bytes::iterator& out) { for (auto&& val : values) { - assert(val.size() <= std::numeric_limits::max()); - write(out, uint16_t(val.size())); + assert(val.size() <= std::numeric_limits::max()); + write(out, size_type(val.size())); out = std::copy(val.begin(), val.end(), out); } } @@ -84,7 +85,7 @@ private: static size_t serialized_size(RangeOfSerializedComponents&& values) { size_t len = 0; for (auto&& val : values) { - len += sizeof(uint16_t) + val.size(); + len += sizeof(size_type) + val.size(); } return len; } @@ -95,8 +96,8 @@ public: template static bytes serialize_value(RangeOfSerializedComponents&& values) { auto size = serialized_size(values); - if (size > std::numeric_limits::max()) { - throw std::runtime_error(sprint("Key size too large: %d > %d", size, std::numeric_limits::max())); + if (size > std::numeric_limits::max()) { + throw std::runtime_error(sprint("Key size too large: %d > %d", size, std::numeric_limits::max())); } bytes b(bytes::initialized_later(), size); auto i = b.begin(); @@ -135,13 +136,13 @@ public: value_type _current; private: void read_current() { - uint16_t len; + size_type len; { if (_v.empty()) { _v = bytes_view(nullptr, 0); return; } - len = read_simple(_v); + len = read_simple(_v); if (_v.size() < len) { throw marshal_exception(); } diff --git a/transport/server.cc b/transport/server.cc index d877f99bd1..c652483210 100644 --- a/transport/server.cc +++ b/transport/server.cc @@ -191,7 +191,7 @@ public: void write_byte(uint8_t b); void write_int(int32_t n); void write_long(int64_t n); - void write_short(int16_t n); + void write_short(uint16_t n); void write_string(const sstring& s); void write_long_string(const sstring& s); void write_uuid(utils::UUID uuid); @@ -1048,9 +1048,9 @@ int64_t cql_server::connection::read_long(bytes_view& buf) return n; } -int16_t cql_server::connection::read_short(bytes_view& buf) +uint16_t cql_server::connection::read_short(bytes_view& buf) { - return static_cast(read_unsigned_short(buf)); + return read_unsigned_short(buf); } uint16_t cql_server::connection::read_unsigned_short(bytes_view& buf) @@ -1357,24 +1357,32 @@ void cql_server::response::write_long(int64_t n) _body.insert(_body.end(), s, s+sizeof(u)); } -void cql_server::response::write_short(int16_t n) +void cql_server::response::write_short(uint16_t n) { auto u = htons(n); auto *s = reinterpret_cast(&u); _body.insert(_body.end(), s, s+sizeof(u)); } +template +inline +T cast_if_fits(size_t v) { + size_t max = std::numeric_limits::max(); + if (v > max) { + throw std::runtime_error(sprint("Value to large, %d > %d", v, max)); + } + return static_cast(v); +} + void cql_server::response::write_string(const sstring& s) { - assert(s.size() < std::numeric_limits::max()); - write_short(s.size()); + write_short(cast_if_fits(s.size())); _body.insert(_body.end(), s.begin(), s.end()); } void cql_server::response::write_long_string(const sstring& s) { - assert(s.size() < std::numeric_limits::max()); - write_int(s.size()); + write_int(cast_if_fits(s.size())); _body.insert(_body.end(), s.begin(), s.end()); } @@ -1386,8 +1394,7 @@ void cql_server::response::write_uuid(utils::UUID uuid) void cql_server::response::write_string_list(std::vector string_list) { - assert(string_list.size() < std::numeric_limits::max()); - write_short(string_list.size()); + write_short(cast_if_fits(string_list.size())); for (auto&& s : string_list) { write_string(s); } @@ -1395,15 +1402,13 @@ void cql_server::response::write_string_list(std::vector string_list) void cql_server::response::write_bytes(bytes b) { - assert(b.size() < std::numeric_limits::max()); - write_int(b.size()); + write_int(cast_if_fits(b.size())); _body.insert(_body.end(), b.begin(), b.end()); } void cql_server::response::write_short_bytes(bytes b) { - assert(b.size() < std::numeric_limits::max()); - write_short(b.size()); + write_short(cast_if_fits(b.size())); _body.insert(_body.end(), b.begin(), b.end()); } @@ -1436,8 +1441,7 @@ void cql_server::response::write_consistency(db::consistency_level c) void cql_server::response::write_string_map(std::map string_map) { - assert(string_map.size() < std::numeric_limits::max()); - write_short(string_map.size()); + write_short(cast_if_fits(string_map.size())); for (auto&& s : string_map) { write_string(s.first); write_string(s.second); @@ -1450,8 +1454,7 @@ void cql_server::response::write_string_multimap(std::multimap for (auto it = string_map.begin(), end = string_map.end(); it != end; it = string_map.upper_bound(it->first)) { keys.push_back(it->first); } - assert(keys.size() < std::numeric_limits::max()); - write_short(keys.size()); + write_short(cast_if_fits(keys.size())); for (auto&& key : keys) { std::vector values; auto range = string_map.equal_range(key); diff --git a/transport/server.hh b/transport/server.hh index 1f1d4c2097..a68e01eb23 100644 --- a/transport/server.hh +++ b/transport/server.hh @@ -178,7 +178,7 @@ private: int8_t read_byte(bytes_view& buf); int32_t read_int(bytes_view& buf); int64_t read_long(bytes_view& buf); - int16_t read_short(bytes_view& buf); + uint16_t read_short(bytes_view& buf); uint16_t read_unsigned_short(bytes_view& buf); sstring read_string(bytes_view& buf); sstring_view read_string_view(bytes_view& buf);