From e97a111f642aff2aa1241a9d4261edfd97e545c1 Mon Sep 17 00:00:00 2001 From: Benny Halevy Date: Tue, 31 Dec 2019 15:50:34 +0200 Subject: [PATCH] cql3: functions: detect and handle int overflow in sum Detect integer overflow in cql sum functions and throw an error. Note that Cassandra quietly truncates the sum if it doesn't fit in the input type but we rather break compatibility in this case. See https://issues.apache.org/jira/browse/CASSANDRA-4914?focusedCommentId=14158400&page=com.atlassian.jira.plugin.system.issuetabpanels%3Acomment-tabpanel#comment-14158400 Fixes #5536 Signed-off-by: Benny Halevy --- cql3/functions/aggregate_fcts.cc | 113 ++++++++++++++++--------------- exceptions/exceptions.hh | 10 +++ 2 files changed, 70 insertions(+), 53 deletions(-) diff --git a/cql3/functions/aggregate_fcts.cc b/cql3/functions/aggregate_fcts.cc index 26c7e71a74..28d6e5f8a3 100644 --- a/cql3/functions/aggregate_fcts.cc +++ b/cql3/functions/aggregate_fcts.cc @@ -44,6 +44,7 @@ #include "aggregate_fcts.hh" #include "functions.hh" #include "native_aggregate_function.hh" +#include "exceptions/exceptions.hh" using namespace cql3; using namespace functions; @@ -75,59 +76,8 @@ public: } }; -template -class impl_sum_function_for final : public aggregate_function::aggregate { - Type _sum{}; -public: - virtual void reset() override { - _sum = {}; - } - virtual opt_bytes compute(cql_serialization_format sf) override { - return data_type_for()->decompose(_sum); - } - virtual void add_input(cql_serialization_format sf, const std::vector& values) override { - if (!values[0]) { - return; - } - _sum += value_cast(data_type_for()->deserialize(*values[0])); - } -}; - -template -class sum_function_for final : public native_aggregate_function { -public: - sum_function_for() : native_aggregate_function("sum", data_type_for(), { data_type_for() }) {} - virtual std::unique_ptr new_aggregate() override { - return std::make_unique>(); - } -}; - - -template -static -shared_ptr -make_sum_function() { - return make_shared>(); -} - -template -class impl_div_for_avg { -public: - static Type div(const Type& x, const int64_t y) { - return x/y; - } -}; - -template <> -class impl_div_for_avg { -public: - static big_decimal div(const big_decimal& x, const int64_t y) { - return x.div(y, big_decimal::rounding_mode::HALF_EVEN); - } -}; - -// We need a wider accumulator for average, since summing the inputs can overflow -// the input type +// We need a wider accumulator for sum and average, +// since summing the inputs can overflow the input type template struct accumulator_for; @@ -171,6 +121,63 @@ struct accumulator_for { using type = big_decimal; }; +template +class impl_sum_function_for final : public aggregate_function::aggregate { + using accumulator_type = typename accumulator_for::type; + accumulator_type _sum{}; +public: + virtual void reset() override { + _sum = {}; + } + virtual opt_bytes compute(cql_serialization_format sf) override { + Type ret = static_cast(_sum); + if (static_cast(ret) != _sum) { + throw exceptions::overflow_error_exception("Sum overflow. Values should be casted to a wider type."); + } + return data_type_for()->decompose(ret); + + } + virtual void add_input(cql_serialization_format sf, const std::vector& values) override { + if (!values[0]) { + return; + } + _sum += value_cast(data_type_for()->deserialize(*values[0])); + } +}; + +template +class sum_function_for final : public native_aggregate_function { +public: + sum_function_for() : native_aggregate_function("sum", data_type_for(), { data_type_for() }) {} + virtual std::unique_ptr new_aggregate() override { + return std::make_unique>(); + } +}; + + +template +static +shared_ptr +make_sum_function() { + return make_shared>(); +} + +template +class impl_div_for_avg { +public: + static Type div(const Type& x, const int64_t y) { + return x/y; + } +}; + +template <> +class impl_div_for_avg { +public: + static big_decimal div(const big_decimal& x, const int64_t y) { + return x.div(y, big_decimal::rounding_mode::HALF_EVEN); + } +}; + template class impl_avg_function_for final : public aggregate_function::aggregate { typename accumulator_for::type _sum{}; diff --git a/exceptions/exceptions.hh b/exceptions/exceptions.hh index c42f73a4b0..7c48b7975a 100644 --- a/exceptions/exceptions.hh +++ b/exceptions/exceptions.hh @@ -68,6 +68,10 @@ enum class exception_code : int32_t { WRITE_FAILURE = 0x1500, CDC_WRITE_FAILURE = 0x1600, + // Scylla-specific codes + // Allocated backwards from 0x1aff-0x1a01 to minimize the chance of collision with Cassandra. + OVERFLOW_ERROR = 0x1aff, + // 2xx: problem validating the request SYNTAX_ERROR = 0x2000, UNAUTHORIZED = 0x2100, @@ -205,6 +209,12 @@ struct read_failure_exception : public request_failure_exception { { } }; +class overflow_error_exception: public cassandra_exception { +public: + overflow_error_exception(sstring msg) noexcept + : cassandra_exception(exception_code::OVERFLOW_ERROR, std::move(msg)) {} +}; + struct overloaded_exception : public cassandra_exception { overloaded_exception(size_t c) noexcept : cassandra_exception(exception_code::OVERLOADED, prepare_message("Too many in flight hints: %lu", c)) {}