cql3: functions: detect and handle int overflow in sum

Detect integer overflow in cql sum functions and throw an error.
Note that Cassandra quietly truncates the sum if it doesn't fit
in the input type but we rather break compatibility in this
case. See https://issues.apache.org/jira/browse/CASSANDRA-4914?focusedCommentId=14158400&page=com.atlassian.jira.plugin.system.issuetabpanels%3Acomment-tabpanel#comment-14158400

Fixes #5536

Signed-off-by: Benny Halevy <bhalevy@scylladb.com>
This commit is contained in:
Benny Halevy
2019-12-31 15:50:34 +02:00
parent 98260254df
commit e97a111f64
2 changed files with 70 additions and 53 deletions

View File

@@ -44,6 +44,7 @@
#include "aggregate_fcts.hh"
#include "functions.hh"
#include "native_aggregate_function.hh"
#include "exceptions/exceptions.hh"
using namespace cql3;
using namespace functions;
@@ -75,59 +76,8 @@ public:
}
};
template <typename Type>
class impl_sum_function_for final : public aggregate_function::aggregate {
Type _sum{};
public:
virtual void reset() override {
_sum = {};
}
virtual opt_bytes compute(cql_serialization_format sf) override {
return data_type_for<Type>()->decompose(_sum);
}
virtual void add_input(cql_serialization_format sf, const std::vector<opt_bytes>& values) override {
if (!values[0]) {
return;
}
_sum += value_cast<Type>(data_type_for<Type>()->deserialize(*values[0]));
}
};
template <typename Type>
class sum_function_for final : public native_aggregate_function {
public:
sum_function_for() : native_aggregate_function("sum", data_type_for<Type>(), { data_type_for<Type>() }) {}
virtual std::unique_ptr<aggregate> new_aggregate() override {
return std::make_unique<impl_sum_function_for<Type>>();
}
};
template <typename Type>
static
shared_ptr<aggregate_function>
make_sum_function() {
return make_shared<sum_function_for<Type>>();
}
template <typename Type>
class impl_div_for_avg {
public:
static Type div(const Type& x, const int64_t y) {
return x/y;
}
};
template <>
class impl_div_for_avg<big_decimal> {
public:
static big_decimal div(const big_decimal& x, const int64_t y) {
return x.div(y, big_decimal::rounding_mode::HALF_EVEN);
}
};
// We need a wider accumulator for average, since summing the inputs can overflow
// the input type
// We need a wider accumulator for sum and average,
// since summing the inputs can overflow the input type
template <typename T>
struct accumulator_for;
@@ -171,6 +121,63 @@ struct accumulator_for<big_decimal> {
using type = big_decimal;
};
template <typename Type>
class impl_sum_function_for final : public aggregate_function::aggregate {
using accumulator_type = typename accumulator_for<Type>::type;
accumulator_type _sum{};
public:
virtual void reset() override {
_sum = {};
}
virtual opt_bytes compute(cql_serialization_format sf) override {
Type ret = static_cast<Type>(_sum);
if (static_cast<accumulator_type>(ret) != _sum) {
throw exceptions::overflow_error_exception("Sum overflow. Values should be casted to a wider type.");
}
return data_type_for<Type>()->decompose(ret);
}
virtual void add_input(cql_serialization_format sf, const std::vector<opt_bytes>& values) override {
if (!values[0]) {
return;
}
_sum += value_cast<Type>(data_type_for<Type>()->deserialize(*values[0]));
}
};
template <typename Type>
class sum_function_for final : public native_aggregate_function {
public:
sum_function_for() : native_aggregate_function("sum", data_type_for<Type>(), { data_type_for<Type>() }) {}
virtual std::unique_ptr<aggregate> new_aggregate() override {
return std::make_unique<impl_sum_function_for<Type>>();
}
};
template <typename Type>
static
shared_ptr<aggregate_function>
make_sum_function() {
return make_shared<sum_function_for<Type>>();
}
template <typename Type>
class impl_div_for_avg {
public:
static Type div(const Type& x, const int64_t y) {
return x/y;
}
};
template <>
class impl_div_for_avg<big_decimal> {
public:
static big_decimal div(const big_decimal& x, const int64_t y) {
return x.div(y, big_decimal::rounding_mode::HALF_EVEN);
}
};
template <typename Type>
class impl_avg_function_for final : public aggregate_function::aggregate {
typename accumulator_for<Type>::type _sum{};

View File

@@ -68,6 +68,10 @@ enum class exception_code : int32_t {
WRITE_FAILURE = 0x1500,
CDC_WRITE_FAILURE = 0x1600,
// Scylla-specific codes
// Allocated backwards from 0x1aff-0x1a01 to minimize the chance of collision with Cassandra.
OVERFLOW_ERROR = 0x1aff,
// 2xx: problem validating the request
SYNTAX_ERROR = 0x2000,
UNAUTHORIZED = 0x2100,
@@ -205,6 +209,12 @@ struct read_failure_exception : public request_failure_exception {
{ }
};
class overflow_error_exception: public cassandra_exception {
public:
overflow_error_exception(sstring msg) noexcept
: cassandra_exception(exception_code::OVERFLOW_ERROR, std::move(msg)) {}
};
struct overloaded_exception : public cassandra_exception {
overloaded_exception(size_t c) noexcept :
cassandra_exception(exception_code::OVERLOADED, prepare_message("Too many in flight hints: %lu", c)) {}