Merge 'functions: reframe aggregate functions in terms of scalar functions' from Avi Kivity
Currently, aggregate functions are implemented in a statefull manner. The accumulator is stored internally in an aggregate_function::aggregate, requiring each query to instantiate new instances (see aggregate_function_selector's constructor, and note how it's called from selector::new_instance()). This makes aggregates hard to use in expressions, since expressions are stateless (with state only provided to evaluate()). To facilitate migration towards stateless expressions, we define a stateless_aggregate_function (modeled after user-defined aggregates, which are already stateless). This new struct defines the aggregate in terms of three scalar functions: one to aggregate a new input into an accumulator (provided in the first parameter), one to finalize an accumulator into a result, and one to reduce two accumulators for parallelized aggregation. All existing native aggregate functions are converted to the new model, and the old interface is removed. This series does not yet convert selectors to expressions, but it does remove one of the obstacles. Performance evaluation: I created a table with a million ints on a single-node cluster, and ran the avg() function on them. I measured the number of instructions executed with `perf stat -p $(pgrep scylla) -e instructions` while the query was running. The query executed from cache, memtables were flushed beforehand. The instruction count per row increased from roughly 49k to roughly 52k, indicating 3k extra instructions per row. While 3k instructions to execute a function is huge, it is currently dwarfed by other overhead (and will be even less important in a cluster where it CL>1 will cause non-coordinator code to run multiple times). Closes #13105 * github.com:scylladb/scylladb: cql3/selection, forward_service: use use stateless_aggregate_function directly db: functions: fold stateless_aggregate_function_adapter into aggregate_function cql3: functions: simplify accumulator_for template cql3: functions: base user-defined aggregates on stateless aggregates cql3: functions: drop native_aggregate_function cql3: functions: reimplement count(column) statelessly cql3: functions: reimplement avg() statelessly cql3: functions: reimplement sum() statelessly cql3: functions: change wide accumulator type to varint cql3: functions: unreverse types for min/max cql3: functions: rename make_{min,max}_dynamic_function cql3: functions: reimplement min/max statelessly cql3: functions: reimplement count(*) statelessly cql3: functions: simplify creating native functions even more cql3: functions: add helpers for automating marshalling for scalar functions types: fix big_decimal constructor from literal 0 cql3: functions: add helper class for internal scalar functions db: functions: add stateless aggregate functions db, cql3: move scalar_function from cql3/functions to db/functions
This commit is contained in:
@@ -875,6 +875,7 @@ scylla_core = (['message/messaging_service.cc',
|
||||
'db/commitlog/commitlog_replayer.cc',
|
||||
'db/commitlog/commitlog_entry.cc',
|
||||
'db/data_listeners.cc',
|
||||
'db/functions/function.cc',
|
||||
'db/hints/manager.cc',
|
||||
'db/hints/resource_manager.cc',
|
||||
'db/hints/host_filter.cc',
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -27,11 +27,11 @@ make_count_rows_function();
|
||||
|
||||
/// The same as `make_max_function()' but with type provided in runtime.
|
||||
shared_ptr<aggregate_function>
|
||||
make_max_dynamic_function(data_type io_type);
|
||||
make_max_function(data_type io_type);
|
||||
|
||||
/// The same as `make_min_function()' but with type provided in runtime.
|
||||
shared_ptr<aggregate_function>
|
||||
make_min_dynamic_function(data_type io_type);
|
||||
make_min_function(data_type io_type);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -274,10 +274,7 @@ static shared_ptr<function> get_dynamic_aggregate(const function_name &name, con
|
||||
}
|
||||
|
||||
auto& arg = arg_types[0];
|
||||
if (arg->is_collection() || arg->is_tuple() || arg->is_user_type()) {
|
||||
return aggregate_fcts::make_min_dynamic_function(arg);
|
||||
}
|
||||
|
||||
return aggregate_fcts::make_min_function(arg);
|
||||
} else if (name.has_keyspace()
|
||||
? name == MAX_NAME
|
||||
: name.name == MAX_NAME.name) {
|
||||
@@ -287,9 +284,7 @@ static shared_ptr<function> get_dynamic_aggregate(const function_name &name, con
|
||||
}
|
||||
|
||||
auto& arg = arg_types[0];
|
||||
if (arg->is_collection() || arg->is_tuple() || arg->is_user_type()) {
|
||||
return aggregate_fcts::make_max_dynamic_function(arg);
|
||||
}
|
||||
return aggregate_fcts::make_max_function(arg);
|
||||
} else if (name.has_keyspace()
|
||||
? name == COUNT_NAME
|
||||
: name.name == COUNT_NAME.name) {
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
/*
|
||||
* Modified by ScyllaDB
|
||||
*
|
||||
* Copyright (C) 2014-present ScyllaDB
|
||||
*/
|
||||
|
||||
/*
|
||||
* SPDX-License-Identifier: (AGPL-3.0-or-later and Apache-2.0)
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "types/types.hh"
|
||||
#include "native_function.hh"
|
||||
#include "aggregate_function.hh"
|
||||
#include <seastar/core/shared_ptr.hh>
|
||||
|
||||
namespace cql3 {
|
||||
namespace functions {
|
||||
|
||||
/**
|
||||
* Base class for the <code>AggregateFunction</code> native classes.
|
||||
*/
|
||||
class native_aggregate_function : public native_function, public aggregate_function {
|
||||
protected:
|
||||
native_aggregate_function(sstring name, data_type return_type,
|
||||
std::vector<data_type> arg_types)
|
||||
: native_function(std::move(name), std::move(return_type), std::move(arg_types)) {
|
||||
}
|
||||
|
||||
public:
|
||||
virtual bool is_aggregate() const override final {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
@@ -10,26 +10,13 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "bytes.hh"
|
||||
#include "function.hh"
|
||||
#include <vector>
|
||||
#include "db/functions/scalar_function.hh"
|
||||
|
||||
namespace cql3 {
|
||||
|
||||
namespace functions {
|
||||
|
||||
class scalar_function : public virtual function {
|
||||
public:
|
||||
/**
|
||||
* Applies this function to the specified parameter.
|
||||
*
|
||||
* @param parameters the input parameters
|
||||
* @return the result of applying this function to the parameter
|
||||
* @throws InvalidRequestException if this function cannot not be applied to the parameter
|
||||
*/
|
||||
virtual bytes_opt execute(const std::vector<bytes_opt>& parameters) = 0;
|
||||
};
|
||||
|
||||
using scalar_function = db::functions::scalar_function;
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,25 +11,15 @@
|
||||
#include "abstract_function.hh"
|
||||
#include "scalar_function.hh"
|
||||
#include "aggregate_function.hh"
|
||||
#include "db/functions/stateless_aggregate_function.hh"
|
||||
#include "data_dictionary/keyspace_element.hh"
|
||||
|
||||
namespace cql3 {
|
||||
namespace functions {
|
||||
|
||||
class user_aggregate : public abstract_function, public aggregate_function, public data_dictionary::keyspace_element {
|
||||
bytes_opt _initcond;
|
||||
::shared_ptr<scalar_function> _sfunc;
|
||||
::shared_ptr<scalar_function> _reducefunc;
|
||||
::shared_ptr<scalar_function> _finalfunc;
|
||||
class user_aggregate : public db::functions::aggregate_function, public data_dictionary::keyspace_element {
|
||||
public:
|
||||
user_aggregate(function_name fname, bytes_opt initcond, ::shared_ptr<scalar_function> sfunc, ::shared_ptr<scalar_function> reducefunc, ::shared_ptr<scalar_function> finalfunc);
|
||||
virtual std::unique_ptr<aggregate_function::aggregate> new_aggregate() override;
|
||||
virtual ::shared_ptr<aggregate_function> reducible_aggregate_function() override;
|
||||
virtual bool is_pure() const override;
|
||||
virtual bool is_native() const override;
|
||||
virtual bool is_aggregate() const override;
|
||||
virtual bool is_reducible() const override;
|
||||
virtual bool requires_thread() const override;
|
||||
bool has_finalfunc() const;
|
||||
|
||||
virtual sstring keypace_name() const override { return name().keyspace; }
|
||||
@@ -38,16 +28,16 @@ public:
|
||||
virtual std::ostream& describe(std::ostream& os) const override;
|
||||
|
||||
seastar::shared_ptr<scalar_function> sfunc() const {
|
||||
return _sfunc;
|
||||
return _agg.aggregation_function;
|
||||
}
|
||||
seastar::shared_ptr<scalar_function> reducefunc() const {
|
||||
return _reducefunc;
|
||||
return _agg.state_reduction_function;
|
||||
}
|
||||
seastar::shared_ptr<scalar_function> finalfunc() const {
|
||||
return _finalfunc;
|
||||
return _agg.state_to_result_function;
|
||||
}
|
||||
const bytes_opt& initcond() const {
|
||||
return _initcond;
|
||||
return _agg.initial_state;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -35,11 +35,13 @@ protected:
|
||||
public:
|
||||
static shared_ptr<factory> new_factory(shared_ptr<functions::function> fun, shared_ptr<selector_factories> factories);
|
||||
|
||||
abstract_function_selector(shared_ptr<functions::function> fun, std::vector<shared_ptr<selector>> arg_selectors)
|
||||
// If reserve_extra_arg is set, the internal buffer used for holding function argument lists is enlarged to account for
|
||||
// an aggregate function's accumulator.
|
||||
abstract_function_selector(shared_ptr<functions::function> fun, std::vector<shared_ptr<selector>> arg_selectors, bool reserve_extra_arg = false)
|
||||
: _fun(std::move(fun)), _arg_selectors(std::move(arg_selectors)),
|
||||
_requires_thread(boost::algorithm::any_of(_arg_selectors, [] (auto& s) { return s->requires_thread(); })
|
||||
|| _fun->requires_thread()) {
|
||||
_args.resize(_arg_selectors.size());
|
||||
_args.resize(_arg_selectors.size() + unsigned(reserve_extra_arg));
|
||||
}
|
||||
|
||||
virtual bool requires_thread() const override;
|
||||
@@ -74,8 +76,10 @@ protected:
|
||||
|
||||
shared_ptr<const T> fun() const { return _tfun; }
|
||||
public:
|
||||
abstract_function_selector_for(shared_ptr<T> fun, std::vector<shared_ptr<selector>> arg_selectors)
|
||||
: abstract_function_selector(fun, std::move(arg_selectors))
|
||||
// If reserve_extra_arg is set, the internal buffer used for holding function argument lists is enlarged to account for
|
||||
// an aggregate function's accumulator.
|
||||
abstract_function_selector_for(shared_ptr<T> fun, std::vector<shared_ptr<selector>> arg_selectors, bool reserve_extra_arg = false)
|
||||
: abstract_function_selector(fun, std::move(arg_selectors), reserve_extra_arg)
|
||||
, _tfun(dynamic_pointer_cast<T>(fun)) {
|
||||
}
|
||||
|
||||
|
||||
@@ -17,7 +17,8 @@ namespace cql3 {
|
||||
namespace selection {
|
||||
|
||||
class aggregate_function_selector : public abstract_function_selector_for<functions::aggregate_function> {
|
||||
std::unique_ptr<functions::aggregate_function::aggregate> _aggregate;
|
||||
const db::functions::stateless_aggregate_function& _aggregate;
|
||||
bytes_opt _accumulator;
|
||||
public:
|
||||
virtual bool is_aggregate() const override {
|
||||
return true;
|
||||
@@ -26,28 +27,32 @@ public:
|
||||
virtual void add_input(result_set_builder& rs) override {
|
||||
// Aggregation of aggregation is not supported
|
||||
size_t m = _arg_selectors.size();
|
||||
_args[0] = std::move(_accumulator);
|
||||
for (size_t i = 0; i < m; ++i) {
|
||||
auto&& s = _arg_selectors[i];
|
||||
s->add_input(rs);
|
||||
_args[i] = s->get_output();
|
||||
_args[i + 1] = s->get_output();
|
||||
s->reset();
|
||||
}
|
||||
_aggregate->add_input(_args);
|
||||
_accumulator = _aggregate.aggregation_function->execute(_args);
|
||||
}
|
||||
|
||||
virtual bytes_opt get_output() override {
|
||||
return _aggregate->compute();
|
||||
return _aggregate.state_to_result_function
|
||||
? _aggregate.state_to_result_function->execute({std::move(_accumulator)})
|
||||
: std::move(_accumulator);
|
||||
}
|
||||
|
||||
virtual void reset() override {
|
||||
_aggregate->reset();
|
||||
_accumulator = _aggregate.initial_state;
|
||||
}
|
||||
|
||||
aggregate_function_selector(shared_ptr<functions::function> func,
|
||||
std::vector<shared_ptr<selector>> arg_selectors)
|
||||
: abstract_function_selector_for<functions::aggregate_function>(
|
||||
dynamic_pointer_cast<functions::aggregate_function>(func), std::move(arg_selectors))
|
||||
, _aggregate(fun()->new_aggregate()) {
|
||||
dynamic_pointer_cast<functions::aggregate_function>(func), std::move(arg_selectors), true)
|
||||
, _aggregate(fun()->get_aggregate())
|
||||
, _accumulator(_aggregate.initial_state) {
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ target_sources(db
|
||||
commitlog/commitlog_replayer.cc
|
||||
commitlog/commitlog_entry.cc
|
||||
data_listeners.cc
|
||||
functions/function.cc
|
||||
hints/manager.cc
|
||||
hints/resource_manager.cc
|
||||
hints/host_filter.cc
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "function.hh"
|
||||
#include "stateless_aggregate_function.hh"
|
||||
#include <optional>
|
||||
|
||||
namespace db {
|
||||
@@ -21,22 +22,28 @@ namespace functions {
|
||||
* Performs a calculation on a set of values and return a single value.
|
||||
*/
|
||||
class aggregate_function : public virtual function {
|
||||
protected:
|
||||
stateless_aggregate_function _agg;
|
||||
private:
|
||||
shared_ptr<aggregate_function> _reducible;
|
||||
private:
|
||||
static shared_ptr<aggregate_function> make_reducible_variant(stateless_aggregate_function saf);
|
||||
public:
|
||||
class aggregate;
|
||||
explicit aggregate_function(stateless_aggregate_function saf, bool reducible_variant = false);
|
||||
|
||||
/**
|
||||
* Creates a new <code>Aggregate</code> instance.
|
||||
*
|
||||
* @return a new <code>Aggregate</code> instance.
|
||||
*/
|
||||
virtual std::unique_ptr<aggregate> new_aggregate() = 0;
|
||||
const stateless_aggregate_function& get_aggregate() const;
|
||||
|
||||
/**
|
||||
* Checks wheather the function can be distributed and is able to reduce states.
|
||||
*
|
||||
* @return <code>true</code> if the function is reducible, <code>false</code> otherwise.
|
||||
*/
|
||||
virtual bool is_reducible() const = 0;
|
||||
bool is_reducible() const;
|
||||
|
||||
/**
|
||||
* Creates a <code>Aggregate Function</code> that can be reduced.
|
||||
@@ -49,42 +56,17 @@ public:
|
||||
*
|
||||
* @return a reducible <code>Aggregate Function</code>.
|
||||
*/
|
||||
virtual ::shared_ptr<aggregate_function> reducible_aggregate_function() = 0;
|
||||
::shared_ptr<aggregate_function> reducible_aggregate_function();
|
||||
|
||||
/**
|
||||
* An aggregation operation.
|
||||
*/
|
||||
class aggregate {
|
||||
public:
|
||||
using opt_bytes = aggregate_function::opt_bytes;
|
||||
|
||||
virtual ~aggregate() {}
|
||||
|
||||
/**
|
||||
* Adds the specified input to this aggregate.
|
||||
*
|
||||
* @param values the values to add to the aggregate.
|
||||
*/
|
||||
virtual void add_input(const std::vector<opt_bytes>& values) = 0;
|
||||
|
||||
/**
|
||||
* Computes and returns the aggregate current value.
|
||||
*
|
||||
* @return the aggregate current value.
|
||||
*/
|
||||
virtual opt_bytes compute() = 0;
|
||||
|
||||
virtual void set_accumulator(const opt_bytes& acc) = 0;
|
||||
|
||||
virtual opt_bytes get_accumulator() const = 0;
|
||||
|
||||
virtual void reduce(const opt_bytes& acc) = 0;
|
||||
|
||||
/**
|
||||
* Reset this aggregate.
|
||||
*/
|
||||
virtual void reset() = 0;
|
||||
};
|
||||
virtual const function_name& name() const override;
|
||||
virtual const std::vector<data_type>& arg_types() const override;
|
||||
virtual const data_type& return_type() const override;
|
||||
virtual bool is_pure() const override;
|
||||
virtual bool is_native() const override;
|
||||
virtual bool requires_thread() const override;
|
||||
virtual bool is_aggregate() const override;
|
||||
virtual void print(std::ostream& os) const override;
|
||||
virtual sstring column_name(const std::vector<sstring>& column_names) const override;
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
93
db/functions/function.cc
Normal file
93
db/functions/function.cc
Normal file
@@ -0,0 +1,93 @@
|
||||
// Copyright (C) 2023-present ScyllaDB
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
|
||||
#include "aggregate_function.hh"
|
||||
|
||||
namespace db::functions {
|
||||
|
||||
aggregate_function::aggregate_function(stateless_aggregate_function agg, bool reducible_variant)
|
||||
: _agg(std::move(agg))
|
||||
, _reducible(!reducible_variant ? make_reducible_variant(_agg) : nullptr) {
|
||||
}
|
||||
|
||||
const stateless_aggregate_function&
|
||||
aggregate_function::get_aggregate() const {
|
||||
return _agg;
|
||||
}
|
||||
|
||||
shared_ptr<aggregate_function>
|
||||
aggregate_function::make_reducible_variant(stateless_aggregate_function agg) {
|
||||
if (!agg.state_reduction_function) {
|
||||
return nullptr;
|
||||
}
|
||||
auto new_agg = agg;
|
||||
new_agg.state_to_result_function = nullptr;
|
||||
new_agg.result_type = new_agg.aggregation_function->return_type();
|
||||
return make_shared<aggregate_function>(new_agg, true);
|
||||
}
|
||||
|
||||
bool
|
||||
aggregate_function::is_reducible() const {
|
||||
return bool(_agg.state_reduction_function);
|
||||
}
|
||||
|
||||
::shared_ptr<aggregate_function>
|
||||
aggregate_function::reducible_aggregate_function() {
|
||||
return _reducible;
|
||||
}
|
||||
|
||||
const function_name&
|
||||
aggregate_function::name() const {
|
||||
return _agg.name;
|
||||
}
|
||||
|
||||
const std::vector<data_type>&
|
||||
aggregate_function::arg_types() const {
|
||||
return _agg.argument_types;
|
||||
}
|
||||
|
||||
const data_type&
|
||||
aggregate_function::return_type() const {
|
||||
return _agg.result_type;
|
||||
}
|
||||
|
||||
bool
|
||||
aggregate_function::is_pure() const {
|
||||
return _agg.aggregation_function->is_pure()
|
||||
&& (!_agg.state_to_result_function || _agg.state_to_result_function->is_pure())
|
||||
&& (!_agg.state_reduction_function || _agg.state_reduction_function->is_pure());
|
||||
}
|
||||
|
||||
bool
|
||||
aggregate_function::is_native() const {
|
||||
return _agg.aggregation_function->is_native()
|
||||
&& (!_agg.state_to_result_function || _agg.state_to_result_function->is_native())
|
||||
&& (!_agg.state_reduction_function || _agg.state_reduction_function->is_native());
|
||||
}
|
||||
|
||||
bool
|
||||
aggregate_function::requires_thread() const {
|
||||
return _agg.aggregation_function->requires_thread()
|
||||
|| (_agg.state_to_result_function && _agg.state_to_result_function->requires_thread())
|
||||
|| (_agg.state_reduction_function && _agg.state_reduction_function->requires_thread());
|
||||
}
|
||||
|
||||
bool
|
||||
aggregate_function::is_aggregate() const {
|
||||
return true;
|
||||
}
|
||||
|
||||
void
|
||||
aggregate_function::print(std::ostream& os) const {
|
||||
os << name();
|
||||
}
|
||||
|
||||
sstring
|
||||
aggregate_function::column_name(const std::vector<sstring>& column_names) const {
|
||||
if (_agg.column_name_override) {
|
||||
return *_agg.column_name_override;
|
||||
}
|
||||
return format("{}({})", _agg.name, fmt::join(column_names, ", "));
|
||||
}
|
||||
|
||||
}
|
||||
32
db/functions/scalar_function.hh
Normal file
32
db/functions/scalar_function.hh
Normal file
@@ -0,0 +1,32 @@
|
||||
/*
|
||||
* Copyright (C) 2014-present ScyllaDB
|
||||
*
|
||||
* Modified by ScyllaDB
|
||||
*/
|
||||
|
||||
/*
|
||||
* SPDX-License-Identifier: (AGPL-3.0-or-later and Apache-2.0)
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "bytes.hh"
|
||||
#include "function.hh"
|
||||
#include <vector>
|
||||
|
||||
namespace db::functions {
|
||||
|
||||
class scalar_function : public virtual function {
|
||||
public:
|
||||
/**
|
||||
* Applies this function to the specified parameter.
|
||||
*
|
||||
* @param parameters the input parameters
|
||||
* @return the result of applying this function to the parameter
|
||||
* @throws InvalidRequestException if this function cannot not be applied to the parameter
|
||||
*/
|
||||
virtual bytes_opt execute(const std::vector<bytes_opt>& parameters) = 0;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
35
db/functions/stateless_aggregate_function.hh
Normal file
35
db/functions/stateless_aggregate_function.hh
Normal file
@@ -0,0 +1,35 @@
|
||||
// Copyright (C) 2023-present ScyllaDB
|
||||
// SPDX-License-Identifier: (AGPL-3.0-or-later and Apache-2.0)
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "scalar_function.hh"
|
||||
#include "function_name.hh"
|
||||
#include <optional>
|
||||
|
||||
namespace db::functions {
|
||||
|
||||
struct stateless_aggregate_function final {
|
||||
function_name name;
|
||||
std::optional<sstring> column_name_override; // if unset, column name is synthesized from name and argument names
|
||||
|
||||
data_type state_type;
|
||||
data_type result_type;
|
||||
std::vector<data_type> argument_types;
|
||||
|
||||
bytes_opt initial_state;
|
||||
|
||||
// aggregates another input
|
||||
// signature: (state_type, argument_types...) -> state_type
|
||||
shared_ptr<scalar_function> aggregation_function;
|
||||
|
||||
// converts the state type to a result
|
||||
// signature: (state_type) -> result_type
|
||||
shared_ptr<scalar_function> state_to_result_function;
|
||||
|
||||
// optional: reduces states computed in parallel
|
||||
// signature: (state_type, state_type) -> state_type
|
||||
shared_ptr<scalar_function> state_reduction_function;
|
||||
};
|
||||
|
||||
}
|
||||
@@ -60,8 +60,7 @@ static std::vector<::shared_ptr<db::functions::aggregate_function>> get_function
|
||||
class forward_aggregates {
|
||||
private:
|
||||
std::vector<::shared_ptr<db::functions::aggregate_function>> _funcs;
|
||||
std::vector<std::unique_ptr<db::functions::aggregate_function::aggregate>> _aggrs;
|
||||
|
||||
std::vector<db::functions::stateless_aggregate_function> _aggrs;
|
||||
public:
|
||||
forward_aggregates(const query::forward_request& request);
|
||||
void merge(query::forward_result& result, query::forward_result&& other);
|
||||
@@ -85,10 +84,10 @@ public:
|
||||
|
||||
forward_aggregates::forward_aggregates(const query::forward_request& request) {
|
||||
_funcs = get_functions(request);
|
||||
std::vector<std::unique_ptr<db::functions::aggregate_function::aggregate>> aggrs;
|
||||
std::vector<db::functions::stateless_aggregate_function> aggrs;
|
||||
|
||||
for (auto& func: _funcs) {
|
||||
aggrs.push_back(func->new_aggregate());
|
||||
aggrs.push_back(func->get_aggregate());
|
||||
}
|
||||
_aggrs = std::move(aggrs);
|
||||
}
|
||||
@@ -113,9 +112,7 @@ void forward_aggregates::merge(query::forward_result &result, query::forward_res
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < _aggrs.size(); i++) {
|
||||
_aggrs[i]->set_accumulator(result.query_results[i]);
|
||||
_aggrs[i]->reduce(std::move(other.query_results[i]));
|
||||
result.query_results[i] = _aggrs[i]->get_accumulator();
|
||||
result.query_results[i] = _aggrs[i].state_reduction_function->execute(std::vector({std::move(result.query_results[i]), std::move(other.query_results[i])}));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -126,7 +123,9 @@ void forward_aggregates::finalize(query::forward_result &result) {
|
||||
// as "WHERE p IN ()". We need to build a fake result with the result
|
||||
// of empty aggregation.
|
||||
for (size_t i = 0; i < _aggrs.size(); i++) {
|
||||
result.query_results.push_back(_aggrs[i]->compute());
|
||||
result.query_results.push_back(_aggrs[i].state_to_result_function
|
||||
? _aggrs[i].state_to_result_function->execute(std::vector({_aggrs[i].initial_state}))
|
||||
: _aggrs[i].initial_state);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -141,8 +140,9 @@ void forward_aggregates::finalize(query::forward_result &result) {
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < _aggrs.size(); i++) {
|
||||
_aggrs[i]->set_accumulator(result.query_results[i]);
|
||||
result.query_results[i] = _aggrs[i]->compute();
|
||||
result.query_results[i] = _aggrs[i].state_to_result_function
|
||||
? _aggrs[i].state_to_result_function->execute(std::vector({std::move(result.query_results[i])}))
|
||||
: result.query_results[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include <boost/multiprecision/cpp_int.hpp>
|
||||
#include <ostream>
|
||||
#include <compare>
|
||||
#include <concepts>
|
||||
|
||||
#include "bytes.hh"
|
||||
|
||||
@@ -29,6 +30,7 @@ public:
|
||||
explicit big_decimal(sstring_view text);
|
||||
big_decimal();
|
||||
big_decimal(int32_t scale, boost::multiprecision::cpp_int unscaled_value);
|
||||
big_decimal(std::integral auto v) : big_decimal(0, v) {}
|
||||
|
||||
int32_t scale() const { return _scale; }
|
||||
const boost::multiprecision::cpp_int& unscaled_value() const { return _unscaled_value; }
|
||||
|
||||
Reference in New Issue
Block a user