mirror of
https://github.com/scylladb/scylladb.git
synced 2026-04-28 12:17:02 +00:00
Other types do not have a wider accumulator at the moment. And static_cast<accumulator_type>(ret) != _sum evaluates as false for NaN/Inf floating point values. Fixes #5586 Signed-off-by: Benny Halevy <bhalevy@scylladb.com> Message-Id: <20200112183436.77951-1-bhalevy@scylladb.com>
613 lines
18 KiB
C++
613 lines
18 KiB
C++
/*
|
|
* Licensed to the Apache Software Foundation (ASF) under one
|
|
* or more contributor license agreements. See the NOTICE file
|
|
* distributed with this work for additional information
|
|
* regarding copyright ownership. The ASF licenses this file
|
|
* to you under the Apache License, Version 2.0 (the
|
|
* "License"); you may not use this file except in compliance
|
|
* with the License. You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
/*
|
|
* Copyright (C) 2019 ScyllaDB
|
|
*
|
|
* Modified by ScyllaDB
|
|
*/
|
|
|
|
/*
|
|
* This file is part of Scylla.
|
|
*
|
|
* Scylla is free software: you can redistribute it and/or modify
|
|
* it under the terms of the GNU Affero General Public License as published by
|
|
* the Free Software Foundation, either version 3 of the License, or
|
|
* (at your option) any later version.
|
|
*
|
|
* Scylla is distributed in the hope that it will be useful,
|
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
* GNU General Public License for more details.
|
|
*
|
|
* You should have received a copy of the GNU General Public License
|
|
* along with Scylla. If not, see <http://www.gnu.org/licenses/>.
|
|
*/
|
|
|
|
|
|
#include "utils/big_decimal.hh"
|
|
#include "aggregate_fcts.hh"
|
|
#include "functions.hh"
|
|
#include "native_aggregate_function.hh"
|
|
#include "exceptions/exceptions.hh"
|
|
|
|
using namespace cql3;
|
|
using namespace functions;
|
|
using namespace aggregate_fcts;
|
|
|
|
namespace {
|
|
class impl_count_function : public aggregate_function::aggregate {
|
|
int64_t _count;
|
|
public:
|
|
virtual void reset() override {
|
|
_count = 0;
|
|
}
|
|
virtual opt_bytes compute(cql_serialization_format sf) override {
|
|
return long_type->decompose(_count);
|
|
}
|
|
virtual void add_input(cql_serialization_format sf, const std::vector<opt_bytes>& values) override {
|
|
++_count;
|
|
}
|
|
};
|
|
|
|
class count_rows_function final : public native_aggregate_function {
|
|
public:
|
|
count_rows_function() : native_aggregate_function(COUNT_ROWS_FUNCTION_NAME, long_type, {}) {}
|
|
virtual std::unique_ptr<aggregate> new_aggregate() override {
|
|
return std::make_unique<impl_count_function>();
|
|
}
|
|
virtual sstring column_name(const std::vector<sstring>& column_names) const override {
|
|
return "count";
|
|
}
|
|
};
|
|
|
|
// Sum and average keep their running total in an accumulator type selected
// per input type: adding many inputs together can overflow the input type
// itself, so the accumulator may have to be wider. The primary template is
// only declared here; each supported input type provides a specialization.
template <typename InputType>
struct accumulator_for;
|
|
|
|
template <typename NarrowType, typename AccType>
|
|
static NarrowType checking_narrow(AccType acc) {
|
|
NarrowType ret = static_cast<NarrowType>(acc);
|
|
if (static_cast<AccType>(ret) != acc) {
|
|
throw exceptions::overflow_error_exception("Sum overflow. Values should be casted to a wider type.");
|
|
}
|
|
return ret;
|
|
}
|
|
|
|
template <>
|
|
struct accumulator_for<int8_t> {
|
|
using type = __int128;
|
|
|
|
static int8_t narrow(type acc) {
|
|
return checking_narrow<int8_t>(acc);
|
|
}
|
|
};
|
|
|
|
template <>
|
|
struct accumulator_for<int16_t> {
|
|
using type = __int128;
|
|
|
|
static int16_t narrow(type acc) {
|
|
return checking_narrow<int16_t>(acc);
|
|
}
|
|
};
|
|
|
|
template <>
|
|
struct accumulator_for<int32_t> {
|
|
using type = __int128;
|
|
|
|
static int32_t narrow(type acc) {
|
|
return checking_narrow<int32_t>(acc);
|
|
}
|
|
};
|
|
|
|
template <>
|
|
struct accumulator_for<int64_t> {
|
|
using type = __int128;
|
|
|
|
static int64_t narrow(type acc) {
|
|
return checking_narrow<int64_t>(acc);
|
|
}
|
|
};
|
|
|
|
template <>
|
|
struct accumulator_for<float> {
|
|
using type = float;
|
|
|
|
static auto narrow(type acc) {
|
|
return acc;
|
|
}
|
|
};
|
|
|
|
template <>
|
|
struct accumulator_for<double> {
|
|
using type = double;
|
|
|
|
static auto narrow(type acc) {
|
|
return acc;
|
|
}
|
|
};
|
|
|
|
template <>
|
|
struct accumulator_for<boost::multiprecision::cpp_int> {
|
|
using type = boost::multiprecision::cpp_int;
|
|
|
|
static auto narrow(type acc) {
|
|
return acc;
|
|
}
|
|
};
|
|
|
|
template <>
|
|
struct accumulator_for<big_decimal> {
|
|
using type = big_decimal;
|
|
|
|
static auto narrow(type acc) {
|
|
return acc;
|
|
}
|
|
};
|
|
|
|
template <typename Type>
|
|
class impl_sum_function_for final : public aggregate_function::aggregate {
|
|
using accumulator_type = typename accumulator_for<Type>::type;
|
|
accumulator_type _sum{};
|
|
public:
|
|
virtual void reset() override {
|
|
_sum = {};
|
|
}
|
|
virtual opt_bytes compute(cql_serialization_format sf) override {
|
|
return data_type_for<Type>()->decompose(accumulator_for<Type>::narrow(_sum));
|
|
}
|
|
virtual void add_input(cql_serialization_format sf, const std::vector<opt_bytes>& values) override {
|
|
if (!values[0]) {
|
|
return;
|
|
}
|
|
_sum += value_cast<Type>(data_type_for<Type>()->deserialize(*values[0]));
|
|
}
|
|
};
|
|
|
|
template <typename Type>
|
|
class sum_function_for final : public native_aggregate_function {
|
|
public:
|
|
sum_function_for() : native_aggregate_function("sum", data_type_for<Type>(), { data_type_for<Type>() }) {}
|
|
virtual std::unique_ptr<aggregate> new_aggregate() override {
|
|
return std::make_unique<impl_sum_function_for<Type>>();
|
|
}
|
|
};
|
|
|
|
|
|
template <typename Type>
|
|
static
|
|
shared_ptr<aggregate_function>
|
|
make_sum_function() {
|
|
return make_shared<sum_function_for<Type>>();
|
|
}
|
|
|
|
template <typename Type>
|
|
class impl_div_for_avg {
|
|
public:
|
|
static Type div(const typename accumulator_for<Type>::type& x, const int64_t y) {
|
|
return x/y;
|
|
}
|
|
};
|
|
|
|
template <>
|
|
class impl_div_for_avg<big_decimal> {
|
|
public:
|
|
static big_decimal div(const big_decimal& x, const int64_t y) {
|
|
return x.div(y, big_decimal::rounding_mode::HALF_EVEN);
|
|
}
|
|
};
|
|
|
|
template <typename Type>
|
|
class impl_avg_function_for final : public aggregate_function::aggregate {
|
|
typename accumulator_for<Type>::type _sum{};
|
|
int64_t _count = 0;
|
|
public:
|
|
virtual void reset() override {
|
|
_sum = {};
|
|
_count = 0;
|
|
}
|
|
virtual opt_bytes compute(cql_serialization_format sf) override {
|
|
Type ret{};
|
|
if (_count) {
|
|
ret = impl_div_for_avg<Type>::div(_sum, _count);
|
|
}
|
|
return data_type_for<Type>()->decompose(ret);
|
|
}
|
|
virtual void add_input(cql_serialization_format sf, const std::vector<opt_bytes>& values) override {
|
|
if (!values[0]) {
|
|
return;
|
|
}
|
|
++_count;
|
|
_sum += value_cast<Type>(data_type_for<Type>()->deserialize(*values[0]));
|
|
}
|
|
};
|
|
|
|
template <typename Type>
|
|
class avg_function_for final : public native_aggregate_function {
|
|
public:
|
|
avg_function_for() : native_aggregate_function("avg", data_type_for<Type>(), { data_type_for<Type>() }) {}
|
|
virtual std::unique_ptr<aggregate> new_aggregate() override {
|
|
return std::make_unique<impl_avg_function_for<Type>>();
|
|
}
|
|
};
|
|
|
|
template <typename Type>
|
|
static
|
|
shared_ptr<aggregate_function>
|
|
make_avg_function() {
|
|
return make_shared<avg_function_for<Type>>();
|
|
}
|
|
|
|
/// Maps a native type to the type in which min/max keep their running value.
/// By default that is the native type itself; specializations override it.
template <typename NativeType>
struct aggregate_type_for {
    using type = NativeType;
};
|
|
|
|
template<>
|
|
struct aggregate_type_for<ascii_native_type> {
|
|
using type = ascii_native_type::primary_type;
|
|
};
|
|
|
|
template<>
|
|
struct aggregate_type_for<simple_date_native_type> {
|
|
using type = simple_date_native_type::primary_type;
|
|
};
|
|
|
|
template<>
|
|
struct aggregate_type_for<timeuuid_native_type> {
|
|
using type = timeuuid_native_type::primary_type;
|
|
};
|
|
|
|
template<>
|
|
struct aggregate_type_for<time_native_type> {
|
|
using type = time_native_type::primary_type;
|
|
};
|
|
|
|
/// Returns a reference to the larger of two values.
///
/// `max` is called unqualified after `using std::max`, so a `max` overload
/// found through argument-dependent lookup is considered alongside std::max.
template <typename Type>
const Type& max_wrapper(const Type& t1, const Type& t2) {
    using std::max;
    const Type& larger = max(t1, t2);
    return larger;
}
|
|
|
|
/// Overload for `net::inet_address`: orders two addresses by a memcmp of
/// their raw data and returns the greater one (`t1` when they compare equal).
///
/// NOTE(review): when either address is in the INET family only
/// sizeof(::in_addr) bytes are compared, even if the other address is longer.
/// Confirm this is the intended ordering for mixed-family inputs.
inline const net::inet_address& max_wrapper(const net::inet_address& t1, const net::inet_address& t2) {
    using family = seastar::net::inet_address::family;

    const bool any_inet = t1.in_family() == family::INET || t2.in_family() == family::INET;
    const size_t len = any_inet ? sizeof(::in_addr) : sizeof(::in6_addr);

    if (std::memcmp(t1.data(), t2.data(), len) >= 0) {
        return t1;
    }
    return t2;
}
|
|
|
|
template <typename Type>
|
|
class impl_max_function_for final : public aggregate_function::aggregate {
|
|
std::optional<typename aggregate_type_for<Type>::type> _max{};
|
|
public:
|
|
virtual void reset() override {
|
|
_max = {};
|
|
}
|
|
virtual opt_bytes compute(cql_serialization_format sf) override {
|
|
if (!_max) {
|
|
return {};
|
|
}
|
|
return data_type_for<Type>()->decompose(data_value(Type{*_max}));
|
|
}
|
|
virtual void add_input(cql_serialization_format sf, const std::vector<opt_bytes>& values) override {
|
|
if (!values[0]) {
|
|
return;
|
|
}
|
|
auto val = value_cast<typename aggregate_type_for<Type>::type>(data_type_for<Type>()->deserialize(*values[0]));
|
|
if (!_max) {
|
|
_max = val;
|
|
} else {
|
|
_max = max_wrapper(*_max, val);
|
|
}
|
|
}
|
|
};
|
|
|
|
/// The same as `impl_max_function_for' but without knowledge of `Type'.
|
|
class impl_max_dynamic_function final : public aggregate_function::aggregate {
|
|
opt_bytes _max;
|
|
public:
|
|
virtual void reset() override {
|
|
_max = {};
|
|
}
|
|
virtual opt_bytes compute(cql_serialization_format sf) override {
|
|
return _max.value_or(bytes{});
|
|
}
|
|
virtual void add_input(cql_serialization_format sf, const std::vector<opt_bytes>& values) override {
|
|
if (!values[0]) {
|
|
return;
|
|
}
|
|
const auto val = *values[0];
|
|
if (!_max || *_max < val) {
|
|
_max = val;
|
|
}
|
|
}
|
|
};
|
|
|
|
template <typename Type>
|
|
class max_function_for final : public native_aggregate_function {
|
|
public:
|
|
max_function_for() : native_aggregate_function("max", data_type_for<Type>(), { data_type_for<Type>() }) {}
|
|
virtual std::unique_ptr<aggregate> new_aggregate() override {
|
|
return std::make_unique<impl_max_function_for<Type>>();
|
|
}
|
|
};
|
|
|
|
class max_dynamic_function final : public native_aggregate_function {
|
|
public:
|
|
max_dynamic_function(data_type io_type) : native_aggregate_function("max", io_type, { io_type }) {}
|
|
virtual std::unique_ptr<aggregate> new_aggregate() override {
|
|
return std::make_unique<impl_max_dynamic_function>();
|
|
}
|
|
};
|
|
|
|
/**
|
|
* Creates a MAX function for the specified type.
|
|
*
|
|
* @param inputType the function input and output type
|
|
* @return a MAX function for the specified type.
|
|
*/
|
|
template <typename Type>
|
|
static
|
|
shared_ptr<aggregate_function>
|
|
make_max_function() {
|
|
return make_shared<max_function_for<Type>>();
|
|
}
|
|
|
|
/// Returns a reference to the smaller of two values.
///
/// `min` is called unqualified after `using std::min`, so a `min` overload
/// found through argument-dependent lookup is considered alongside std::min.
template <typename Type>
const Type& min_wrapper(const Type& t1, const Type& t2) {
    using std::min;
    const Type& smaller = min(t1, t2);
    return smaller;
}
|
|
|
|
/// Overload for `net::inet_address`: orders two addresses by a memcmp of
/// their raw data and returns the lesser one (`t1` when they compare equal).
///
/// NOTE(review): when either address is in the INET family only
/// sizeof(::in_addr) bytes are compared, even if the other address is longer.
/// Confirm this is the intended ordering for mixed-family inputs.
inline const net::inet_address& min_wrapper(const net::inet_address& t1, const net::inet_address& t2) {
    using family = seastar::net::inet_address::family;

    const bool any_inet = t1.in_family() == family::INET || t2.in_family() == family::INET;
    const size_t len = any_inet ? sizeof(::in_addr) : sizeof(::in6_addr);

    if (std::memcmp(t1.data(), t2.data(), len) <= 0) {
        return t1;
    }
    return t2;
}
|
|
|
|
template <typename Type>
|
|
class impl_min_function_for final : public aggregate_function::aggregate {
|
|
std::optional<typename aggregate_type_for<Type>::type> _min{};
|
|
public:
|
|
virtual void reset() override {
|
|
_min = {};
|
|
}
|
|
virtual opt_bytes compute(cql_serialization_format sf) override {
|
|
if (!_min) {
|
|
return {};
|
|
}
|
|
return data_type_for<Type>()->decompose(data_value(Type{*_min}));
|
|
}
|
|
virtual void add_input(cql_serialization_format sf, const std::vector<opt_bytes>& values) override {
|
|
if (!values[0]) {
|
|
return;
|
|
}
|
|
auto val = value_cast<typename aggregate_type_for<Type>::type>(data_type_for<Type>()->deserialize(*values[0]));
|
|
if (!_min) {
|
|
_min = val;
|
|
} else {
|
|
_min = min_wrapper(*_min, val);
|
|
}
|
|
}
|
|
};
|
|
|
|
/// The same as `impl_min_function_for' but without knowledge of `Type'.
|
|
class impl_min_dynamic_function final : public aggregate_function::aggregate {
|
|
opt_bytes _min;
|
|
public:
|
|
virtual void reset() override {
|
|
_min = {};
|
|
}
|
|
virtual opt_bytes compute(cql_serialization_format sf) override {
|
|
return _min.value_or(bytes{});
|
|
}
|
|
virtual void add_input(cql_serialization_format sf, const std::vector<opt_bytes>& values) override {
|
|
if (!values[0]) {
|
|
return;
|
|
}
|
|
const auto val = *values[0];
|
|
if (!_min || val < *_min) {
|
|
_min = val;
|
|
}
|
|
}
|
|
};
|
|
|
|
template <typename Type>
|
|
class min_function_for final : public native_aggregate_function {
|
|
public:
|
|
min_function_for() : native_aggregate_function("min", data_type_for<Type>(), { data_type_for<Type>() }) {}
|
|
virtual std::unique_ptr<aggregate> new_aggregate() override {
|
|
return std::make_unique<impl_min_function_for<Type>>();
|
|
}
|
|
};
|
|
|
|
class min_dynamic_function final : public native_aggregate_function {
|
|
public:
|
|
min_dynamic_function(data_type io_type) : native_aggregate_function("min", io_type, { io_type }) {}
|
|
virtual std::unique_ptr<aggregate> new_aggregate() override {
|
|
return std::make_unique<impl_min_dynamic_function>();
|
|
}
|
|
};
|
|
|
|
/**
|
|
* Creates a MIN function for the specified type.
|
|
*
|
|
* @param inputType the function input and output type
|
|
* @return a MIN function for the specified type.
|
|
*/
|
|
template <typename Type>
|
|
static
|
|
shared_ptr<aggregate_function>
|
|
make_min_function() {
|
|
return make_shared<min_function_for<Type>>();
|
|
}
|
|
|
|
template <typename Type>
|
|
class impl_count_function_for final : public aggregate_function::aggregate {
|
|
int64_t _count = 0;
|
|
public:
|
|
virtual void reset() override {
|
|
_count = 0;
|
|
}
|
|
virtual opt_bytes compute(cql_serialization_format sf) override {
|
|
return long_type->decompose(_count);
|
|
}
|
|
virtual void add_input(cql_serialization_format sf, const std::vector<opt_bytes>& values) override {
|
|
if (!values[0]) {
|
|
return;
|
|
}
|
|
++_count;
|
|
}
|
|
};
|
|
|
|
template <typename Type>
|
|
class count_function_for final : public native_aggregate_function {
|
|
public:
|
|
count_function_for() : native_aggregate_function("count", long_type, { data_type_for<Type>() }) {}
|
|
virtual std::unique_ptr<aggregate> new_aggregate() override {
|
|
return std::make_unique<impl_count_function_for<Type>>();
|
|
}
|
|
};
|
|
|
|
/**
|
|
* Creates a COUNT function for the specified type.
|
|
*
|
|
* @param inputType the function input type
|
|
* @return a COUNT function for the specified type.
|
|
*/
|
|
template <typename Type>
|
|
static shared_ptr<aggregate_function> make_count_function() {
|
|
return make_shared<count_function_for<Type>>();
|
|
}
|
|
}
|
|
|
|
/// Creates the row-counting aggregate (a count_rows_function instance).
shared_ptr<aggregate_function> aggregate_fcts::make_count_rows_function() {
    using function_type = count_rows_function;
    return make_shared<function_type>();
}
|
|
|
|
/// Creates a MAX aggregate whose input and output type is `io_type`, for
/// types that are only known at run time.
shared_ptr<aggregate_function> aggregate_fcts::make_max_dynamic_function(data_type io_type) {
    using function_type = max_dynamic_function;
    return make_shared<function_type>(io_type);
}
|
|
|
|
/// Creates a MIN aggregate whose input and output type is `io_type`, for
/// types that are only known at run time.
shared_ptr<aggregate_function> aggregate_fcts::make_min_dynamic_function(data_type io_type) {
    using function_type = min_dynamic_function;
    return make_shared<function_type>(io_type);
}
|
|
|
|
/// Registers the built-in aggregate functions in `funcs`.
///
/// Every function is inserted under its own name(). For each supported native
/// type the count, max and min aggregates are added (in that order), followed
/// by sum and then avg for the numeric types.
void cql3::functions::add_agg_functions(declared_t& funcs) {
    auto register_fn = [&funcs] (shared_ptr<function> fn) {
        funcs.emplace(fn->name(), fn);
    };

    // count / max / min, one group per native type
    register_fn(make_count_function<int8_t>());
    register_fn(make_max_function<int8_t>());
    register_fn(make_min_function<int8_t>());

    register_fn(make_count_function<int16_t>());
    register_fn(make_max_function<int16_t>());
    register_fn(make_min_function<int16_t>());

    register_fn(make_count_function<int32_t>());
    register_fn(make_max_function<int32_t>());
    register_fn(make_min_function<int32_t>());

    register_fn(make_count_function<int64_t>());
    register_fn(make_max_function<int64_t>());
    register_fn(make_min_function<int64_t>());

    register_fn(make_count_function<boost::multiprecision::cpp_int>());
    register_fn(make_max_function<boost::multiprecision::cpp_int>());
    register_fn(make_min_function<boost::multiprecision::cpp_int>());

    register_fn(make_count_function<big_decimal>());
    register_fn(make_max_function<big_decimal>());
    register_fn(make_min_function<big_decimal>());

    register_fn(make_count_function<float>());
    register_fn(make_max_function<float>());
    register_fn(make_min_function<float>());

    register_fn(make_count_function<double>());
    register_fn(make_max_function<double>());
    register_fn(make_min_function<double>());

    register_fn(make_count_function<sstring>());
    register_fn(make_max_function<sstring>());
    register_fn(make_min_function<sstring>());

    register_fn(make_count_function<ascii_native_type>());
    register_fn(make_max_function<ascii_native_type>());
    register_fn(make_min_function<ascii_native_type>());

    register_fn(make_count_function<simple_date_native_type>());
    register_fn(make_max_function<simple_date_native_type>());
    register_fn(make_min_function<simple_date_native_type>());

    register_fn(make_count_function<db_clock::time_point>());
    register_fn(make_max_function<db_clock::time_point>());
    register_fn(make_min_function<db_clock::time_point>());

    register_fn(make_count_function<timeuuid_native_type>());
    register_fn(make_max_function<timeuuid_native_type>());
    register_fn(make_min_function<timeuuid_native_type>());

    register_fn(make_count_function<time_native_type>());
    register_fn(make_max_function<time_native_type>());
    register_fn(make_min_function<time_native_type>());

    register_fn(make_count_function<utils::UUID>());
    register_fn(make_max_function<utils::UUID>());
    register_fn(make_min_function<utils::UUID>());

    register_fn(make_count_function<bytes>());
    register_fn(make_max_function<bytes>());
    register_fn(make_min_function<bytes>());

    register_fn(make_count_function<bool>());
    register_fn(make_max_function<bool>());
    register_fn(make_min_function<bool>());

    register_fn(make_count_function<net::inet_address>());
    register_fn(make_max_function<net::inet_address>());
    register_fn(make_min_function<net::inet_address>());

    // FIXME: more count/min/max

    // sum, numeric types only
    register_fn(make_sum_function<int8_t>());
    register_fn(make_sum_function<int16_t>());
    register_fn(make_sum_function<int32_t>());
    register_fn(make_sum_function<int64_t>());
    register_fn(make_sum_function<float>());
    register_fn(make_sum_function<double>());
    register_fn(make_sum_function<boost::multiprecision::cpp_int>());
    register_fn(make_sum_function<big_decimal>());

    // avg, numeric types only
    register_fn(make_avg_function<int8_t>());
    register_fn(make_avg_function<int16_t>());
    register_fn(make_avg_function<int32_t>());
    register_fn(make_avg_function<int64_t>());
    register_fn(make_avg_function<float>());
    register_fn(make_avg_function<double>());
    register_fn(make_avg_function<boost::multiprecision::cpp_int>());
    register_fn(make_avg_function<big_decimal>());
}
|