Merge "cql3: detect and handle int overflow in aggregate functions #5537" from Benny
" Fix overflow handling in sum() and avg(). sum: - aggregated into __int128 - detect overflow when computing result and log a warning if found avg: - fix division function to divide the accumulator type _sum (__int128 for integers) by _count Add unit tests for both cases Test: - manual test against Cassandra 3.11.3 to make sure the results in the scylla unit test agree with it. - unit(dev), cql_query_test(debug) Fixes #5536 " * 'cql3-sum-overflow' of https://github.com/bhalevy/scylla: test: cql_query_test: test avg overflow cql3: functions: protect against int overflow in avg test: cql_query_test: test sum overflow cql3: functions: detect and handle int overflow in sum exceptions: sort exception_code definitions exceptions: define additional cassandra CQL exceptions codes
This commit is contained in:
@@ -44,6 +44,7 @@
|
||||
#include "aggregate_fcts.hh"
|
||||
#include "functions.hh"
|
||||
#include "native_aggregate_function.hh"
|
||||
#include "exceptions/exceptions.hh"
|
||||
|
||||
using namespace cql3;
|
||||
using namespace functions;
|
||||
@@ -75,59 +76,8 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Type>
|
||||
class impl_sum_function_for final : public aggregate_function::aggregate {
|
||||
Type _sum{};
|
||||
public:
|
||||
virtual void reset() override {
|
||||
_sum = {};
|
||||
}
|
||||
virtual opt_bytes compute(cql_serialization_format sf) override {
|
||||
return data_type_for<Type>()->decompose(_sum);
|
||||
}
|
||||
virtual void add_input(cql_serialization_format sf, const std::vector<opt_bytes>& values) override {
|
||||
if (!values[0]) {
|
||||
return;
|
||||
}
|
||||
_sum += value_cast<Type>(data_type_for<Type>()->deserialize(*values[0]));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Type>
|
||||
class sum_function_for final : public native_aggregate_function {
|
||||
public:
|
||||
sum_function_for() : native_aggregate_function("sum", data_type_for<Type>(), { data_type_for<Type>() }) {}
|
||||
virtual std::unique_ptr<aggregate> new_aggregate() override {
|
||||
return std::make_unique<impl_sum_function_for<Type>>();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename Type>
|
||||
static
|
||||
shared_ptr<aggregate_function>
|
||||
make_sum_function() {
|
||||
return make_shared<sum_function_for<Type>>();
|
||||
}
|
||||
|
||||
template <typename Type>
|
||||
class impl_div_for_avg {
|
||||
public:
|
||||
static Type div(const Type& x, const int64_t y) {
|
||||
return x/y;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class impl_div_for_avg<big_decimal> {
|
||||
public:
|
||||
static big_decimal div(const big_decimal& x, const int64_t y) {
|
||||
return x.div(y, big_decimal::rounding_mode::HALF_EVEN);
|
||||
}
|
||||
};
|
||||
|
||||
// We need a wider accumulator for average, since summing the inputs can overflow
|
||||
// the input type
|
||||
// We need a wider accumulator for sum and average,
|
||||
// since summing the inputs can overflow the input type
|
||||
template <typename T>
|
||||
struct accumulator_for;
|
||||
|
||||
@@ -171,6 +121,63 @@ struct accumulator_for<big_decimal> {
|
||||
using type = big_decimal;
|
||||
};
|
||||
|
||||
template <typename Type>
|
||||
class impl_sum_function_for final : public aggregate_function::aggregate {
|
||||
using accumulator_type = typename accumulator_for<Type>::type;
|
||||
accumulator_type _sum{};
|
||||
public:
|
||||
virtual void reset() override {
|
||||
_sum = {};
|
||||
}
|
||||
virtual opt_bytes compute(cql_serialization_format sf) override {
|
||||
Type ret = static_cast<Type>(_sum);
|
||||
if (static_cast<accumulator_type>(ret) != _sum) {
|
||||
throw exceptions::overflow_error_exception("Sum overflow. Values should be casted to a wider type.");
|
||||
}
|
||||
return data_type_for<Type>()->decompose(ret);
|
||||
|
||||
}
|
||||
virtual void add_input(cql_serialization_format sf, const std::vector<opt_bytes>& values) override {
|
||||
if (!values[0]) {
|
||||
return;
|
||||
}
|
||||
_sum += value_cast<Type>(data_type_for<Type>()->deserialize(*values[0]));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Type>
|
||||
class sum_function_for final : public native_aggregate_function {
|
||||
public:
|
||||
sum_function_for() : native_aggregate_function("sum", data_type_for<Type>(), { data_type_for<Type>() }) {}
|
||||
virtual std::unique_ptr<aggregate> new_aggregate() override {
|
||||
return std::make_unique<impl_sum_function_for<Type>>();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename Type>
|
||||
static
|
||||
shared_ptr<aggregate_function>
|
||||
make_sum_function() {
|
||||
return make_shared<sum_function_for<Type>>();
|
||||
}
|
||||
|
||||
template <typename Type>
|
||||
class impl_div_for_avg {
|
||||
public:
|
||||
static Type div(const typename accumulator_for<Type>::type& x, const int64_t y) {
|
||||
return x/y;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class impl_div_for_avg<big_decimal> {
|
||||
public:
|
||||
static big_decimal div(const big_decimal& x, const int64_t y) {
|
||||
return x.div(y, big_decimal::rounding_mode::HALF_EVEN);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Type>
|
||||
class impl_avg_function_for final : public aggregate_function::aggregate {
|
||||
typename accumulator_for<Type>::type _sum{};
|
||||
|
||||
@@ -62,9 +62,15 @@ enum class exception_code : int32_t {
|
||||
IS_BOOTSTRAPPING= 0x1002,
|
||||
TRUNCATE_ERROR = 0x1003,
|
||||
WRITE_TIMEOUT = 0x1100,
|
||||
WRITE_FAILURE = 0x1500,
|
||||
READ_TIMEOUT = 0x1200,
|
||||
READ_FAILURE = 0x1300,
|
||||
FUNCTION_FAILURE= 0x1400,
|
||||
WRITE_FAILURE = 0x1500,
|
||||
CDC_WRITE_FAILURE = 0x1600,
|
||||
|
||||
// Scylla-specific codes
|
||||
// Allocated backwards from 0x1aff-0x1a01 to minimize the chance of collision with Cassandra.
|
||||
OVERFLOW_ERROR = 0x1aff,
|
||||
|
||||
// 2xx: problem validating the request
|
||||
SYNTAX_ERROR = 0x2000,
|
||||
@@ -203,6 +209,12 @@ struct read_failure_exception : public request_failure_exception {
|
||||
{ }
|
||||
};
|
||||
|
||||
class overflow_error_exception: public cassandra_exception {
|
||||
public:
|
||||
overflow_error_exception(sstring msg) noexcept
|
||||
: cassandra_exception(exception_code::OVERFLOW_ERROR, std::move(msg)) {}
|
||||
};
|
||||
|
||||
struct overloaded_exception : public cassandra_exception {
|
||||
overloaded_exception(size_t c) noexcept :
|
||||
cassandra_exception(exception_code::OVERLOADED, prepare_message("Too many in flight hints: %lu", c)) {}
|
||||
|
||||
@@ -4313,3 +4313,123 @@ SEASTAR_TEST_CASE(test_rf_expand) {
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
SEASTAR_TEST_CASE(test_int_sum_overflow) {
|
||||
return do_with_cql_env_thread([] (cql_test_env& e) {
|
||||
cquery_nofail(e, "create table cf (pk text, ck text, val int, primary key(pk, ck));");
|
||||
cquery_nofail(e, "insert into cf (pk, ck, val) values ('p1', 'c1', 2147483647);");
|
||||
cquery_nofail(e, "insert into cf (pk, ck, val) values ('p1', 'c2', 1);");
|
||||
auto sum_query = "select sum(val) from cf;";
|
||||
BOOST_REQUIRE_THROW(e.execute_cql(sum_query).get(), exceptions::overflow_error_exception);
|
||||
|
||||
cquery_nofail(e, "insert into cf (pk, ck, val) values ('p2', 'c1', -1);");
|
||||
auto result = e.execute_cql(sum_query).get0();
|
||||
assert_that(result)
|
||||
.is_rows()
|
||||
.with_size(1)
|
||||
.with_row({int32_type->decompose(int32_t(2147483647))});
|
||||
|
||||
cquery_nofail(e, "insert into cf (pk, ck, val) values ('p3', 'c1', 2147483647);");
|
||||
BOOST_REQUIRE_THROW(e.execute_cql(sum_query).get(), exceptions::overflow_error_exception);
|
||||
|
||||
cquery_nofail(e, "insert into cf (pk, ck, val) values ('p3', 'c2', -2147483648);");
|
||||
result = e.execute_cql(sum_query).get0();
|
||||
assert_that(result)
|
||||
.is_rows()
|
||||
.with_size(1)
|
||||
.with_row({int32_type->decompose(int32_t(2147483646))});
|
||||
});
|
||||
}
|
||||
|
||||
SEASTAR_TEST_CASE(test_bigint_sum_overflow) {
|
||||
return do_with_cql_env_thread([] (cql_test_env& e) {
|
||||
cquery_nofail(e, "create table cf (pk text, ck text, val bigint, primary key(pk, ck));");
|
||||
cquery_nofail(e, "insert into cf (pk, ck, val) values ('p1', 'c1', 9223372036854775807);");
|
||||
cquery_nofail(e, "insert into cf (pk, ck, val) values ('p1', 'c2', 1);");
|
||||
auto sum_query = "select sum(val) from cf;";
|
||||
BOOST_REQUIRE_THROW(e.execute_cql(sum_query).get(), exceptions::overflow_error_exception);
|
||||
|
||||
cquery_nofail(e, "insert into cf (pk, ck, val) values ('p2', 'c1', -1);");
|
||||
auto result = e.execute_cql(sum_query).get0();
|
||||
assert_that(result)
|
||||
.is_rows()
|
||||
.with_size(1)
|
||||
.with_row({long_type->decompose(int64_t(9223372036854775807))});
|
||||
|
||||
cquery_nofail(e, "insert into cf (pk, ck, val) values ('p3', 'c1', 9223372036854775807);");
|
||||
BOOST_REQUIRE_THROW(e.execute_cql(sum_query).get(), exceptions::overflow_error_exception);
|
||||
|
||||
cquery_nofail(e, "insert into cf (pk, ck, val) values ('p3', 'c2', -9223372036854775808);");
|
||||
result = e.execute_cql(sum_query).get0();
|
||||
assert_that(result)
|
||||
.is_rows()
|
||||
.with_size(1)
|
||||
.with_row({long_type->decompose(int64_t(9223372036854775806))});
|
||||
});
|
||||
}
|
||||
|
||||
SEASTAR_TEST_CASE(test_bigint_sum) {
|
||||
return do_with_cql_env_thread([] (cql_test_env& e) {
|
||||
cquery_nofail(e, "create table cf (pk text, val bigint, primary key(pk));");
|
||||
cquery_nofail(e, "insert into cf (pk, val) values ('x', 2147483647);");
|
||||
cquery_nofail(e, "insert into cf (pk, val) values ('y', 2147483647);");
|
||||
auto sum_query = "select sum(val) from cf;";
|
||||
assert_that(e.execute_cql(sum_query).get0())
|
||||
.is_rows()
|
||||
.with_size(1)
|
||||
.with_row({long_type->decompose(int64_t(4294967294))});
|
||||
|
||||
cquery_nofail(e, "insert into cf (pk, val) values ('z', -4294967295);");
|
||||
assert_that(e.execute_cql(sum_query).get0())
|
||||
.is_rows()
|
||||
.with_size(1)
|
||||
.with_row({long_type->decompose(int64_t(-1))});
|
||||
});
|
||||
}
|
||||
|
||||
SEASTAR_TEST_CASE(test_int_sum_with_cast) {
|
||||
return do_with_cql_env_thread([] (cql_test_env& e) {
|
||||
return do_with_cql_env_thread([] (cql_test_env& e) {
|
||||
cquery_nofail(e, "create table cf (pk text, val int, primary key(pk));");
|
||||
cquery_nofail(e, "insert into cf (pk, val) values ('a', 2147483647);");
|
||||
cquery_nofail(e, "insert into cf (pk, val) values ('b', 2147483647);");
|
||||
auto sum_as_bigint_query = "select sum(val as bigint) from cf;";
|
||||
assert_that(e.execute_cql(sum_as_bigint_query).get0())
|
||||
.is_rows()
|
||||
.with_size(1)
|
||||
.with_row({long_type->decompose(int64_t(4294967294))});
|
||||
|
||||
cquery_nofail(e, "insert into cf (pk, val) values ('a', -2147483648);");
|
||||
cquery_nofail(e, "insert into cf (pk, val) values ('b', -2147483647);");
|
||||
assert_that(e.execute_cql(sum_as_bigint_query).get0())
|
||||
.is_rows()
|
||||
.with_size(1)
|
||||
.with_row({long_type->decompose(int64_t(-4294967296))});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
SEASTAR_TEST_CASE(test_int_avg) {
|
||||
return do_with_cql_env_thread([] (cql_test_env& e) {
|
||||
cquery_nofail(e, "create table cf (pk text, val int, primary key(pk));");
|
||||
cquery_nofail(e, "insert into cf (pk, val) values ('a', 2147483647);");
|
||||
cquery_nofail(e, "insert into cf (pk, val) values ('b', 2147483647);");
|
||||
auto result = e.execute_cql("select avg(val) from cf;").get0();
|
||||
assert_that(result)
|
||||
.is_rows()
|
||||
.with_size(1)
|
||||
.with_row({int32_type->decompose(int32_t(2147483647))});
|
||||
});
|
||||
}
|
||||
|
||||
SEASTAR_TEST_CASE(test_bigint_avg) {
|
||||
return do_with_cql_env_thread([] (cql_test_env& e) {
|
||||
cquery_nofail(e, "create table cf (pk text, val bigint, primary key(pk));");
|
||||
cquery_nofail(e, "insert into cf (pk, val) values ('x', 9223372036854775807);");
|
||||
cquery_nofail(e, "insert into cf (pk, val) values ('y', 9223372036854775807);");
|
||||
assert_that(e.execute_cql("select avg(val) from cf;").get0())
|
||||
.is_rows()
|
||||
.with_size(1)
|
||||
.with_row({long_type->decompose(int64_t(9223372036854775807))});
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user