diff --git a/cql3/selection/selection.cc b/cql3/selection/selection.cc index d3491fedbd..d557f1be32 100644 --- a/cql3/selection/selection.cc +++ b/cql3/selection/selection.cc @@ -23,6 +23,7 @@ #include "cql3/restrictions/statement_restrictions.hh" #include "cql3/expr/evaluate.hh" #include "cql3/expr/expr-utils.hh" +#include "cql3/functions/first_function.hh" #include "cql3/functions/aggregate_fcts.hh" namespace cql3 { @@ -195,6 +196,9 @@ class selection_with_processing : public selection { private: ::shared_ptr _factories; std::vector _selectors; + std::vector _inner_loop; + std::vector _outer_loop; + std::vector _initial_values_for_temporaries; public: selection_with_processing(schema_ptr schema, std::vector columns, std::vector> metadata, ::shared_ptr factories, @@ -204,13 +208,42 @@ public: contains_ttl(expr::tuple_constructor{selectors})) , _factories(std::move(factories)) , _selectors(std::move(selectors)) - { } + { + auto agg_split = expr::split_aggregation(_selectors); + _outer_loop = std::move(agg_split.outer_loop); + _inner_loop = std::move(agg_split.inner_loop); + _initial_values_for_temporaries = std::move(agg_split.initial_values_for_temporaries); + } virtual uint32_t add_column_for_post_processing(const column_definition& c) override { uint32_t index = selection::add_column_for_post_processing(c); _factories->add_selector_for_post_processing(c, index); _selectors.push_back(expr::column_value(&c)); - return index; + if (_inner_loop.empty()) { + // Simple case: no aggregation + return index; + } else { + // Complex case: aggregation, must pass through temporary + auto first_func = cql3::functions::aggregate_fcts::make_first_function(c.type); + auto& agg = first_func->get_aggregate(); + auto temp_index = _initial_values_for_temporaries.size(); + auto temp = expr::temporary{ + .index = temp_index, + .type = agg.argument_types[0], + }; + _inner_loop.push_back( + expr::function_call{ + .func = agg.aggregation_function, + .args = {temp, expr::column_value(&c)}, + }); + _initial_values_for_temporaries.push_back(raw_value::make_value(agg.initial_state)); + _outer_loop.push_back( + expr::function_call{ + .func = agg.state_to_result_function, + .args = {temp}, + }); + return _outer_loop.size() - 1; + } } virtual bool is_aggregate() const override { @@ -316,12 +349,14 @@ protected: ::shared_ptr _factories; std::vector<::shared_ptr> _selectors; const selection_with_processing& _sel; + std::vector _temporaries; bool _requires_thread; public: selectors_with_processing(const selection_with_processing& sel, ::shared_ptr factories) : _factories(std::move(factories)) , _selectors(_factories->new_instances()) , _sel(sel) + , _temporaries(_sel._initial_values_for_temporaries) , _requires_thread(boost::algorithm::any_of(sel._selectors, [] (const expr::expression& e) { return expr::find_in_expression(e, [] (const expr::function_call& fc) { return std::get>(fc.func)->requires_thread(); @@ -334,9 +369,7 @@ protected: } virtual void reset() override { - for (auto&& s : _selectors) { - s->reset(); - } + _temporaries = _sel._initial_values_for_temporaries; } virtual bool is_aggregate() const override { @@ -365,16 +398,37 @@ protected: virtual std::vector get_output_row() override { std::vector output_row; - output_row.reserve(_selectors.size()); - for (auto&& s : _selectors) { - output_row.emplace_back(s->get_output()); + output_row.reserve(_sel._outer_loop.size()); + auto inputs = expr::evaluation_inputs{ + .partition_key = {}, + .clustering_key = {}, + .static_and_regular_columns = {}, + .selection = &_sel, + .options = nullptr, + .static_and_regular_timestamps = {}, + .static_and_regular_ttls = {}, + .temporaries = _temporaries, + }; + for (auto&& e : _sel._outer_loop) { + auto out = expr::evaluate(e, inputs); + output_row.emplace_back(std::move(out).to_managed_bytes_opt()); } return output_row; } virtual void add_input_row(result_set_builder& rs) override { - for (auto&& s : _selectors) { - s->add_input(rs); + auto inputs = expr::evaluation_inputs{ + .partition_key = rs.current_partition_key, + .clustering_key = rs.current_clustering_key, + .static_and_regular_columns = rs.current, + .selection = &_sel, + .options = nullptr, + .static_and_regular_timestamps = rs._timestamps, + .static_and_regular_ttls = rs._ttls, + .temporaries = _temporaries, + }; + for (size_t i = 0; i != _sel._inner_loop.size(); ++i) { + _temporaries[i] = expr::evaluate(_sel._inner_loop[i], inputs); } }