Merge "cql3: detect and handle int overflow in aggregate functions #5537" from Benny

"
Fix overflow handling in sum() and avg().

sum:
 - aggregated into __int128
 - detect overflow when computing result and log a warning if found

avg:
 - fix division function to divide the accumulator type _sum (__int128 for integers) by _count

Add unit tests for both cases

Test:
  - manual test against Cassandra 3.11.3 to make sure the results in the scylla unit test agree with it.
  - unit(dev), cql_query_test(debug)

Fixes #5536
"

* 'cql3-sum-overflow' of https://github.com/bhalevy/scylla:
  test: cql_query_test: test avg overflow
  cql3: functions: protect against int overflow in avg
  test: cql_query_test: test sum overflow
  cql3: functions: detect and handle int overflow in sum
  exceptions: sort exception_code definitions
  exceptions: define additional cassandra CQL exceptions codes
This commit is contained in:
Avi Kivity
2020-01-08 10:39:38 +02:00
3 changed files with 193 additions and 54 deletions

View File

@@ -44,6 +44,7 @@
#include "aggregate_fcts.hh"
#include "functions.hh"
#include "native_aggregate_function.hh"
#include "exceptions/exceptions.hh"
using namespace cql3;
using namespace functions;
@@ -75,59 +76,8 @@ public:
}
};
template <typename Type>
class impl_sum_function_for final : public aggregate_function::aggregate {
Type _sum{};
public:
virtual void reset() override {
_sum = {};
}
virtual opt_bytes compute(cql_serialization_format sf) override {
return data_type_for<Type>()->decompose(_sum);
}
virtual void add_input(cql_serialization_format sf, const std::vector<opt_bytes>& values) override {
if (!values[0]) {
return;
}
_sum += value_cast<Type>(data_type_for<Type>()->deserialize(*values[0]));
}
};
template <typename Type>
class sum_function_for final : public native_aggregate_function {
public:
sum_function_for() : native_aggregate_function("sum", data_type_for<Type>(), { data_type_for<Type>() }) {}
virtual std::unique_ptr<aggregate> new_aggregate() override {
return std::make_unique<impl_sum_function_for<Type>>();
}
};
template <typename Type>
static
shared_ptr<aggregate_function>
make_sum_function() {
return make_shared<sum_function_for<Type>>();
}
template <typename Type>
class impl_div_for_avg {
public:
static Type div(const Type& x, const int64_t y) {
return x/y;
}
};
template <>
class impl_div_for_avg<big_decimal> {
public:
static big_decimal div(const big_decimal& x, const int64_t y) {
return x.div(y, big_decimal::rounding_mode::HALF_EVEN);
}
};
// We need a wider accumulator for average, since summing the inputs can overflow
// the input type
// We need a wider accumulator for sum and average,
// since summing the inputs can overflow the input type
template <typename T>
struct accumulator_for;
@@ -171,6 +121,63 @@ struct accumulator_for<big_decimal> {
using type = big_decimal;
};
template <typename Type>
class impl_sum_function_for final : public aggregate_function::aggregate {
using accumulator_type = typename accumulator_for<Type>::type;
accumulator_type _sum{};
public:
virtual void reset() override {
_sum = {};
}
virtual opt_bytes compute(cql_serialization_format sf) override {
Type ret = static_cast<Type>(_sum);
if (static_cast<accumulator_type>(ret) != _sum) {
throw exceptions::overflow_error_exception("Sum overflow. Values should be casted to a wider type.");
}
return data_type_for<Type>()->decompose(ret);
}
virtual void add_input(cql_serialization_format sf, const std::vector<opt_bytes>& values) override {
if (!values[0]) {
return;
}
_sum += value_cast<Type>(data_type_for<Type>()->deserialize(*values[0]));
}
};
template <typename Type>
class sum_function_for final : public native_aggregate_function {
public:
sum_function_for() : native_aggregate_function("sum", data_type_for<Type>(), { data_type_for<Type>() }) {}
virtual std::unique_ptr<aggregate> new_aggregate() override {
return std::make_unique<impl_sum_function_for<Type>>();
}
};
template <typename Type>
static
shared_ptr<aggregate_function>
make_sum_function() {
return make_shared<sum_function_for<Type>>();
}
template <typename Type>
class impl_div_for_avg {
public:
static Type div(const typename accumulator_for<Type>::type& x, const int64_t y) {
return x/y;
}
};
template <>
class impl_div_for_avg<big_decimal> {
public:
static big_decimal div(const big_decimal& x, const int64_t y) {
return x.div(y, big_decimal::rounding_mode::HALF_EVEN);
}
};
template <typename Type>
class impl_avg_function_for final : public aggregate_function::aggregate {
typename accumulator_for<Type>::type _sum{};

View File

@@ -62,9 +62,15 @@ enum class exception_code : int32_t {
IS_BOOTSTRAPPING= 0x1002,
TRUNCATE_ERROR = 0x1003,
WRITE_TIMEOUT = 0x1100,
WRITE_FAILURE = 0x1500,
READ_TIMEOUT = 0x1200,
READ_FAILURE = 0x1300,
FUNCTION_FAILURE= 0x1400,
WRITE_FAILURE = 0x1500,
CDC_WRITE_FAILURE = 0x1600,
// Scylla-specific codes
// Allocated backwards from 0x1aff-0x1a01 to minimize the chance of collision with Cassandra.
OVERFLOW_ERROR = 0x1aff,
// 2xx: problem validating the request
SYNTAX_ERROR = 0x2000,
@@ -203,6 +209,12 @@ struct read_failure_exception : public request_failure_exception {
{ }
};
class overflow_error_exception: public cassandra_exception {
public:
overflow_error_exception(sstring msg) noexcept
: cassandra_exception(exception_code::OVERFLOW_ERROR, std::move(msg)) {}
};
struct overloaded_exception : public cassandra_exception {
overloaded_exception(size_t c) noexcept :
cassandra_exception(exception_code::OVERLOADED, prepare_message("Too many in flight hints: %lu", c)) {}

View File

@@ -4313,3 +4313,123 @@ SEASTAR_TEST_CASE(test_rf_expand) {
});
});
}
SEASTAR_TEST_CASE(test_int_sum_overflow) {
return do_with_cql_env_thread([] (cql_test_env& e) {
cquery_nofail(e, "create table cf (pk text, ck text, val int, primary key(pk, ck));");
cquery_nofail(e, "insert into cf (pk, ck, val) values ('p1', 'c1', 2147483647);");
cquery_nofail(e, "insert into cf (pk, ck, val) values ('p1', 'c2', 1);");
auto sum_query = "select sum(val) from cf;";
BOOST_REQUIRE_THROW(e.execute_cql(sum_query).get(), exceptions::overflow_error_exception);
cquery_nofail(e, "insert into cf (pk, ck, val) values ('p2', 'c1', -1);");
auto result = e.execute_cql(sum_query).get0();
assert_that(result)
.is_rows()
.with_size(1)
.with_row({int32_type->decompose(int32_t(2147483647))});
cquery_nofail(e, "insert into cf (pk, ck, val) values ('p3', 'c1', 2147483647);");
BOOST_REQUIRE_THROW(e.execute_cql(sum_query).get(), exceptions::overflow_error_exception);
cquery_nofail(e, "insert into cf (pk, ck, val) values ('p3', 'c2', -2147483648);");
result = e.execute_cql(sum_query).get0();
assert_that(result)
.is_rows()
.with_size(1)
.with_row({int32_type->decompose(int32_t(2147483646))});
});
}
SEASTAR_TEST_CASE(test_bigint_sum_overflow) {
return do_with_cql_env_thread([] (cql_test_env& e) {
cquery_nofail(e, "create table cf (pk text, ck text, val bigint, primary key(pk, ck));");
cquery_nofail(e, "insert into cf (pk, ck, val) values ('p1', 'c1', 9223372036854775807);");
cquery_nofail(e, "insert into cf (pk, ck, val) values ('p1', 'c2', 1);");
auto sum_query = "select sum(val) from cf;";
BOOST_REQUIRE_THROW(e.execute_cql(sum_query).get(), exceptions::overflow_error_exception);
cquery_nofail(e, "insert into cf (pk, ck, val) values ('p2', 'c1', -1);");
auto result = e.execute_cql(sum_query).get0();
assert_that(result)
.is_rows()
.with_size(1)
.with_row({long_type->decompose(int64_t(9223372036854775807))});
cquery_nofail(e, "insert into cf (pk, ck, val) values ('p3', 'c1', 9223372036854775807);");
BOOST_REQUIRE_THROW(e.execute_cql(sum_query).get(), exceptions::overflow_error_exception);
cquery_nofail(e, "insert into cf (pk, ck, val) values ('p3', 'c2', -9223372036854775808);");
result = e.execute_cql(sum_query).get0();
assert_that(result)
.is_rows()
.with_size(1)
.with_row({long_type->decompose(int64_t(9223372036854775806))});
});
}
SEASTAR_TEST_CASE(test_bigint_sum) {
return do_with_cql_env_thread([] (cql_test_env& e) {
cquery_nofail(e, "create table cf (pk text, val bigint, primary key(pk));");
cquery_nofail(e, "insert into cf (pk, val) values ('x', 2147483647);");
cquery_nofail(e, "insert into cf (pk, val) values ('y', 2147483647);");
auto sum_query = "select sum(val) from cf;";
assert_that(e.execute_cql(sum_query).get0())
.is_rows()
.with_size(1)
.with_row({long_type->decompose(int64_t(4294967294))});
cquery_nofail(e, "insert into cf (pk, val) values ('z', -4294967295);");
assert_that(e.execute_cql(sum_query).get0())
.is_rows()
.with_size(1)
.with_row({long_type->decompose(int64_t(-1))});
});
}
SEASTAR_TEST_CASE(test_int_sum_with_cast) {
return do_with_cql_env_thread([] (cql_test_env& e) {
return do_with_cql_env_thread([] (cql_test_env& e) {
cquery_nofail(e, "create table cf (pk text, val int, primary key(pk));");
cquery_nofail(e, "insert into cf (pk, val) values ('a', 2147483647);");
cquery_nofail(e, "insert into cf (pk, val) values ('b', 2147483647);");
auto sum_as_bigint_query = "select sum(val as bigint) from cf;";
assert_that(e.execute_cql(sum_as_bigint_query).get0())
.is_rows()
.with_size(1)
.with_row({long_type->decompose(int64_t(4294967294))});
cquery_nofail(e, "insert into cf (pk, val) values ('a', -2147483648);");
cquery_nofail(e, "insert into cf (pk, val) values ('b', -2147483647);");
assert_that(e.execute_cql(sum_as_bigint_query).get0())
.is_rows()
.with_size(1)
.with_row({long_type->decompose(int64_t(-4294967296))});
});
});
}
SEASTAR_TEST_CASE(test_int_avg) {
return do_with_cql_env_thread([] (cql_test_env& e) {
cquery_nofail(e, "create table cf (pk text, val int, primary key(pk));");
cquery_nofail(e, "insert into cf (pk, val) values ('a', 2147483647);");
cquery_nofail(e, "insert into cf (pk, val) values ('b', 2147483647);");
auto result = e.execute_cql("select avg(val) from cf;").get0();
assert_that(result)
.is_rows()
.with_size(1)
.with_row({int32_type->decompose(int32_t(2147483647))});
});
}
SEASTAR_TEST_CASE(test_bigint_avg) {
return do_with_cql_env_thread([] (cql_test_env& e) {
cquery_nofail(e, "create table cf (pk text, val bigint, primary key(pk));");
cquery_nofail(e, "insert into cf (pk, val) values ('x', 9223372036854775807);");
cquery_nofail(e, "insert into cf (pk, val) values ('y', 9223372036854775807);");
assert_that(e.execute_cql("select avg(val) from cf;").get0())
.is_rows()
.with_size(1)
.with_row({long_type->decompose(int64_t(9223372036854775807))});
});
}