cql3: functions: detect and handle int overflow in sum
Detect integer overflow in cql sum functions and throw an error. Note that Cassandra quietly truncates the sum if it doesn't fit in the input type but we rather break compatibility in this case. See https://issues.apache.org/jira/browse/CASSANDRA-4914?focusedCommentId=14158400&page=com.atlassian.jira.plugin.system.issuetabpanels%3Acomment-tabpanel#comment-14158400 Fixes #5536 Signed-off-by: Benny Halevy <bhalevy@scylladb.com>
This commit is contained in:
@@ -44,6 +44,7 @@
|
||||
#include "aggregate_fcts.hh"
|
||||
#include "functions.hh"
|
||||
#include "native_aggregate_function.hh"
|
||||
#include "exceptions/exceptions.hh"
|
||||
|
||||
using namespace cql3;
|
||||
using namespace functions;
|
||||
@@ -75,59 +76,8 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Type>
|
||||
class impl_sum_function_for final : public aggregate_function::aggregate {
|
||||
Type _sum{};
|
||||
public:
|
||||
virtual void reset() override {
|
||||
_sum = {};
|
||||
}
|
||||
virtual opt_bytes compute(cql_serialization_format sf) override {
|
||||
return data_type_for<Type>()->decompose(_sum);
|
||||
}
|
||||
virtual void add_input(cql_serialization_format sf, const std::vector<opt_bytes>& values) override {
|
||||
if (!values[0]) {
|
||||
return;
|
||||
}
|
||||
_sum += value_cast<Type>(data_type_for<Type>()->deserialize(*values[0]));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Type>
|
||||
class sum_function_for final : public native_aggregate_function {
|
||||
public:
|
||||
sum_function_for() : native_aggregate_function("sum", data_type_for<Type>(), { data_type_for<Type>() }) {}
|
||||
virtual std::unique_ptr<aggregate> new_aggregate() override {
|
||||
return std::make_unique<impl_sum_function_for<Type>>();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename Type>
|
||||
static
|
||||
shared_ptr<aggregate_function>
|
||||
make_sum_function() {
|
||||
return make_shared<sum_function_for<Type>>();
|
||||
}
|
||||
|
||||
template <typename Type>
|
||||
class impl_div_for_avg {
|
||||
public:
|
||||
static Type div(const Type& x, const int64_t y) {
|
||||
return x/y;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class impl_div_for_avg<big_decimal> {
|
||||
public:
|
||||
static big_decimal div(const big_decimal& x, const int64_t y) {
|
||||
return x.div(y, big_decimal::rounding_mode::HALF_EVEN);
|
||||
}
|
||||
};
|
||||
|
||||
// We need a wider accumulator for average, since summing the inputs can overflow
|
||||
// the input type
|
||||
// We need a wider accumulator for sum and average,
|
||||
// since summing the inputs can overflow the input type
|
||||
template <typename T>
|
||||
struct accumulator_for;
|
||||
|
||||
@@ -171,6 +121,63 @@ struct accumulator_for<big_decimal> {
|
||||
using type = big_decimal;
|
||||
};
|
||||
|
||||
template <typename Type>
|
||||
class impl_sum_function_for final : public aggregate_function::aggregate {
|
||||
using accumulator_type = typename accumulator_for<Type>::type;
|
||||
accumulator_type _sum{};
|
||||
public:
|
||||
virtual void reset() override {
|
||||
_sum = {};
|
||||
}
|
||||
virtual opt_bytes compute(cql_serialization_format sf) override {
|
||||
Type ret = static_cast<Type>(_sum);
|
||||
if (static_cast<accumulator_type>(ret) != _sum) {
|
||||
throw exceptions::overflow_error_exception("Sum overflow. Values should be casted to a wider type.");
|
||||
}
|
||||
return data_type_for<Type>()->decompose(ret);
|
||||
|
||||
}
|
||||
virtual void add_input(cql_serialization_format sf, const std::vector<opt_bytes>& values) override {
|
||||
if (!values[0]) {
|
||||
return;
|
||||
}
|
||||
_sum += value_cast<Type>(data_type_for<Type>()->deserialize(*values[0]));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Type>
|
||||
class sum_function_for final : public native_aggregate_function {
|
||||
public:
|
||||
sum_function_for() : native_aggregate_function("sum", data_type_for<Type>(), { data_type_for<Type>() }) {}
|
||||
virtual std::unique_ptr<aggregate> new_aggregate() override {
|
||||
return std::make_unique<impl_sum_function_for<Type>>();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename Type>
|
||||
static
|
||||
shared_ptr<aggregate_function>
|
||||
make_sum_function() {
|
||||
return make_shared<sum_function_for<Type>>();
|
||||
}
|
||||
|
||||
template <typename Type>
|
||||
class impl_div_for_avg {
|
||||
public:
|
||||
static Type div(const Type& x, const int64_t y) {
|
||||
return x/y;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class impl_div_for_avg<big_decimal> {
|
||||
public:
|
||||
static big_decimal div(const big_decimal& x, const int64_t y) {
|
||||
return x.div(y, big_decimal::rounding_mode::HALF_EVEN);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Type>
|
||||
class impl_avg_function_for final : public aggregate_function::aggregate {
|
||||
typename accumulator_for<Type>::type _sum{};
|
||||
|
||||
@@ -68,6 +68,10 @@ enum class exception_code : int32_t {
|
||||
WRITE_FAILURE = 0x1500,
|
||||
CDC_WRITE_FAILURE = 0x1600,
|
||||
|
||||
// Scylla-specific codes
|
||||
// Allocated backwards from 0x1aff-0x1a01 to minimize the chance of collision with Cassandra.
|
||||
OVERFLOW_ERROR = 0x1aff,
|
||||
|
||||
// 2xx: problem validating the request
|
||||
SYNTAX_ERROR = 0x2000,
|
||||
UNAUTHORIZED = 0x2100,
|
||||
@@ -205,6 +209,12 @@ struct read_failure_exception : public request_failure_exception {
|
||||
{ }
|
||||
};
|
||||
|
||||
class overflow_error_exception: public cassandra_exception {
|
||||
public:
|
||||
overflow_error_exception(sstring msg) noexcept
|
||||
: cassandra_exception(exception_code::OVERFLOW_ERROR, std::move(msg)) {}
|
||||
};
|
||||
|
||||
struct overloaded_exception : public cassandra_exception {
|
||||
overloaded_exception(size_t c) noexcept :
|
||||
cassandra_exception(exception_code::OVERLOADED, prepare_message("Too many in flight hints: %lu", c)) {}
|
||||
|
||||
Reference in New Issue
Block a user