diff --git a/cql3/functions/aggregate_fcts.cc b/cql3/functions/aggregate_fcts.cc index 26c7e71a74..28d6e5f8a3 100644 --- a/cql3/functions/aggregate_fcts.cc +++ b/cql3/functions/aggregate_fcts.cc @@ -44,6 +44,7 @@ #include "aggregate_fcts.hh" #include "functions.hh" #include "native_aggregate_function.hh" +#include "exceptions/exceptions.hh" using namespace cql3; using namespace functions; @@ -75,59 +76,8 @@ public: } }; -template -class impl_sum_function_for final : public aggregate_function::aggregate { - Type _sum{}; -public: - virtual void reset() override { - _sum = {}; - } - virtual opt_bytes compute(cql_serialization_format sf) override { - return data_type_for()->decompose(_sum); - } - virtual void add_input(cql_serialization_format sf, const std::vector& values) override { - if (!values[0]) { - return; - } - _sum += value_cast(data_type_for()->deserialize(*values[0])); - } -}; - -template -class sum_function_for final : public native_aggregate_function { -public: - sum_function_for() : native_aggregate_function("sum", data_type_for(), { data_type_for() }) {} - virtual std::unique_ptr new_aggregate() override { - return std::make_unique>(); - } -}; - - -template -static -shared_ptr -make_sum_function() { - return make_shared>(); -} - -template -class impl_div_for_avg { -public: - static Type div(const Type& x, const int64_t y) { - return x/y; - } -}; - -template <> -class impl_div_for_avg { -public: - static big_decimal div(const big_decimal& x, const int64_t y) { - return x.div(y, big_decimal::rounding_mode::HALF_EVEN); - } -}; - -// We need a wider accumulator for average, since summing the inputs can overflow -// the input type +// We need a wider accumulator for sum and average, +// since summing the inputs can overflow the input type template struct accumulator_for; @@ -171,6 +121,63 @@ struct accumulator_for { using type = big_decimal; }; +template +class impl_sum_function_for final : public aggregate_function::aggregate { + using accumulator_type = typename accumulator_for::type; + accumulator_type _sum{}; +public: + virtual void reset() override { + _sum = {}; + } + virtual opt_bytes compute(cql_serialization_format sf) override { + Type ret = static_cast(_sum); + if (static_cast(ret) != _sum) { + throw exceptions::overflow_error_exception("Sum overflow. Values should be casted to a wider type."); + } + return data_type_for()->decompose(ret); + + } + virtual void add_input(cql_serialization_format sf, const std::vector& values) override { + if (!values[0]) { + return; + } + _sum += value_cast(data_type_for()->deserialize(*values[0])); + } +}; + +template +class sum_function_for final : public native_aggregate_function { +public: + sum_function_for() : native_aggregate_function("sum", data_type_for(), { data_type_for() }) {} + virtual std::unique_ptr new_aggregate() override { + return std::make_unique>(); + } +}; + + +template +static +shared_ptr +make_sum_function() { + return make_shared>(); +} + +template +class impl_div_for_avg { +public: + static Type div(const Type& x, const int64_t y) { + return x/y; + } +}; + +template <> +class impl_div_for_avg { +public: + static big_decimal div(const big_decimal& x, const int64_t y) { + return x.div(y, big_decimal::rounding_mode::HALF_EVEN); + } +}; + template class impl_avg_function_for final : public aggregate_function::aggregate { typename accumulator_for::type _sum{}; diff --git a/exceptions/exceptions.hh b/exceptions/exceptions.hh index c42f73a4b0..7c48b7975a 100644 --- a/exceptions/exceptions.hh +++ b/exceptions/exceptions.hh @@ -68,6 +68,10 @@ enum class exception_code : int32_t { WRITE_FAILURE = 0x1500, CDC_WRITE_FAILURE = 0x1600, + // Scylla-specific codes + // Allocated backwards from 0x1aff-0x1a01 to minimize the chance of collision with Cassandra. + OVERFLOW_ERROR = 0x1aff, + // 2xx: problem validating the request SYNTAX_ERROR = 0x2000, UNAUTHORIZED = 0x2100, @@ -205,6 +209,12 @@ struct read_failure_exception : public request_failure_exception { { } }; +class overflow_error_exception: public cassandra_exception { +public: + overflow_error_exception(sstring msg) noexcept + : cassandra_exception(exception_code::OVERFLOW_ERROR, std::move(msg)) {} +}; + struct overloaded_exception : public cassandra_exception { overloaded_exception(size_t c) noexcept : cassandra_exception(exception_code::OVERLOADED, prepare_message("Too many in flight hints: %lu", c)) {}