diff --git a/cql3/functions/aggregate_fcts.hh b/cql3/functions/aggregate_fcts.hh index 27f05ddbdd..7a0de47a72 100644 --- a/cql3/functions/aggregate_fcts.hh +++ b/cql3/functions/aggregate_fcts.hh @@ -292,6 +292,27 @@ public: } }; +/// The same as `impl_max_function_for' but without knowledge of `Type'. +class impl_max_dynamic_function final : public aggregate_function::aggregate { + opt_bytes _max; +public: + virtual void reset() override { + _max = {}; + } + virtual opt_bytes compute(cql_serialization_format sf) override { + return _max.value_or(bytes{}); + } + virtual void add_input(cql_serialization_format sf, const std::vector& values) override { + if (!values[0]) { + return; + } + const auto val = *values[0]; + if (!_max || *_max < val) { + _max = val; + } + } +}; + template class max_function_for final : public native_aggregate_function { public: @@ -301,6 +322,14 @@ public: } }; +class max_dynamic_function final : public native_aggregate_function { +public: + max_dynamic_function(data_type io_type) : native_aggregate_function("max", io_type, { io_type }) {} + virtual std::unique_ptr new_aggregate() override { + return std::make_unique(); + } +}; + /** * Creates a MAX function for the specified type. * @@ -313,6 +342,12 @@ make_max_function() { return make_shared>(); } +/// The same as `make_max_function()' but with type provided in runtime. +inline shared_ptr +make_max_dynamic_function(data_type io_type) { + return make_shared(io_type); +} + template const Type& min_wrapper(const Type& t1, const Type& t2) { using std::min; @@ -353,6 +388,27 @@ public: } }; +/// The same as `impl_min_function_for' but without knowledge of `Type'. +class impl_min_dynamic_function final : public aggregate_function::aggregate { + opt_bytes _min; +public: + virtual void reset() override { + _min = {}; + } + virtual opt_bytes compute(cql_serialization_format sf) override { + return _min.value_or(bytes{}); + } + virtual void add_input(cql_serialization_format sf, const std::vector& values) override { + if (!values[0]) { + return; + } + const auto val = *values[0]; + if (!_min || val < *_min) { + _min = val; + } + } +}; + template class min_function_for final : public native_aggregate_function { public: @@ -362,6 +418,13 @@ public: } }; +class min_dynamic_function final : public native_aggregate_function { +public: + min_dynamic_function(data_type io_type) : native_aggregate_function("min", io_type, { io_type }) {} + virtual std::unique_ptr new_aggregate() override { + return std::make_unique(); + } +}; /** * Creates a MIN function for the specified type. @@ -375,6 +438,11 @@ make_min_function() { return make_shared>(); } +/// The same as `make_min_function()' but with type provided in runtime. +inline shared_ptr +make_min_dynamic_function(data_type io_type) { + return make_shared(io_type); +} template class impl_count_function_for final : public aggregate_function::aggregate { diff --git a/cql3/functions/functions.cc b/cql3/functions/functions.cc index bf4dafca28..5e3d2c7bdb 100644 --- a/cql3/functions/functions.cc +++ b/cql3/functions/functions.cc @@ -280,6 +280,8 @@ functions::get(database& db, static const function_name TOKEN_FUNCTION_NAME = function_name::native_function("token"); static const function_name TO_JSON_FUNCTION_NAME = function_name::native_function("tojson"); static const function_name FROM_JSON_FUNCTION_NAME = function_name::native_function("fromjson"); + static const function_name MIN_FUNCTION_NAME = function_name::native_function("min"); + static const function_name MAX_FUNCTION_NAME = function_name::native_function("max"); if (name.has_keyspace() ? name == TOKEN_FUNCTION_NAME @@ -312,6 +314,40 @@ functions::get(database& db, return make_from_json_function(db, keyspace, receiver->type); } + if (name.has_keyspace() + ? name == MIN_FUNCTION_NAME + : name.name == MIN_FUNCTION_NAME.name) { + if (provided_args.size() != 1) { + throw exceptions::invalid_request_exception("min() operates on 1 argument at a time"); + } + selection::selector *sp = dynamic_cast(provided_args[0].get()); + if (!sp) { + throw exceptions::invalid_request_exception("min() is only valid in SELECT clause"); + } + const data_type arg_type = sp->get_type(); + if (arg_type->is_collection() || arg_type->is_tuple() || arg_type->is_user_type()) { + // `min()' function is created on demand for arguments of compound types. + return aggregate_fcts::make_min_dynamic_function(arg_type); + } + } + + if (name.has_keyspace() + ? name == MAX_FUNCTION_NAME + : name.name == MAX_FUNCTION_NAME.name) { + if (provided_args.size() != 1) { + throw exceptions::invalid_request_exception("max() operates on 1 argument at a time"); + } + selection::selector *sp = dynamic_cast(provided_args[0].get()); + if (!sp) { + throw exceptions::invalid_request_exception("max() is only valid in SELECT clause"); + } + const data_type arg_type = sp->get_type(); + if (arg_type->is_collection() || arg_type->is_tuple() || arg_type->is_user_type()) { + // `max()' function is created on demand for arguments of compound types. + return aggregate_fcts::make_max_dynamic_function(arg_type); + } + } + std::vector> candidates; auto&& add_declared = [&] (function_name fn) { auto&& fns = _declared.equal_range(fn); diff --git a/tests/cql_query_test.cc b/tests/cql_query_test.cc index 43d9e59008..11e0197425 100644 --- a/tests/cql_query_test.cc +++ b/tests/cql_query_test.cc @@ -1470,7 +1470,12 @@ struct aggregate_function_test { std::vector _sorted_values; sstring table_name() { - return "cf_" + _column_type->cql3_type_name(); + sstring tbl_name = "cf_" + _column_type->cql3_type_name(); + // Substitute troublesome characters from `cql3_type_name()': + std::for_each(tbl_name.begin(), tbl_name.end(), [] (char& c) { + if (c == '<' || c == '>' || c == ',' || c == ' ') { c = '_'; } + }); + return tbl_name; } void call_function_and_expect(const char* fname, data_type type, data_value expected) { auto msg = _e.execute_cql(format("select {}(value) from {}", fname, table_name())).get0(); @@ -1631,6 +1636,34 @@ SEASTAR_TEST_CASE(test_aggregate_functions) { net::inet_address("1::1"), net::inet_address("1.0.0.1") ).test_min_max_count(); + + auto list_type_int = list_type_impl::get_instance(int32_type, false); + aggregate_function_test(e, list_type_int, + make_list_value(list_type_int, {1, 2, 3}), + make_list_value(list_type_int, {1, 2, 4}), + make_list_value(list_type_int, {2, 2, 3}) + ).test_min_max_count(); + + auto set_type_int = set_type_impl::get_instance(int32_type, false); + aggregate_function_test(e, set_type_int, + make_set_value(set_type_int, {1, 2, 3}), + make_set_value(set_type_int, {1, 2, 4}), + make_set_value(set_type_int, {2, 3, 4}) + ).test_min_max_count(); + + auto tuple_type_int_text = tuple_type_impl::get_instance({int32_type, utf8_type}); + aggregate_function_test(e, tuple_type_int_text, + make_tuple_value(tuple_type_int_text, {1, "aaa"}), + make_tuple_value(tuple_type_int_text, {1, "bbb"}), + make_tuple_value(tuple_type_int_text, {2, "aaa"}) + ).test_min_max_count(); + + auto map_type_int_text = map_type_impl::get_instance(int32_type, utf8_type, false); + aggregate_function_test(e, map_type_int_text, + make_map_value(map_type_int_text, {std::make_pair(data_value(1), data_value("asdf"))}), + make_map_value(map_type_int_text, {std::make_pair(data_value(2), data_value("asdf"))}), + make_map_value(map_type_int_text, {std::make_pair(data_value(2), data_value("bsdf"))}) + ).test_min_max_count(); }); }