From 9efad36fb8a6e280c6f36b798923b13edc42e001 Mon Sep 17 00:00:00 2001 From: Juliusz Stasiewicz Date: Thu, 5 Dec 2019 13:19:55 +0100 Subject: [PATCH 1/2] cql3: min()/max() for collections/tuples/UDTs do not cast to blobs Before: cqlsh> insert into ks.list_types (id, val) values (1, [3,4,5]); cqlsh> select max(val) from ks.list_types; system.max(val) ------------------------------------------------------------ 0x00000003000000040000000300000004000000040000000400000005 After: cqlsh> select max(val) from ks.list_types; system.max(val) -------------------- [3, 4, 5] This is accomplished similarly to `tojson()`/`fromjson()`: functions are generated on demand from within `cql3::functions::get()`. Because collections can have a variety of types, including UDTs and tuples, it would be impossible to statically define max(T t)->T for every T. Until now, max(blob)->blob overload was used. Because `impl_max/min_function_for` is templated with the input/output type, which can be defined in runtime, we need type-erased ("dynamic") versions of these functors. They work identically, i.e. they compare byte representations of lhs and rhs with `bytes::operator<`. Resolves #5139 --- cql3/functions/aggregate_fcts.hh | 68 ++++++++++++++++++++++++++++++++ cql3/functions/functions.cc | 36 +++++++++++++++++ 2 files changed, 104 insertions(+) diff --git a/cql3/functions/aggregate_fcts.hh b/cql3/functions/aggregate_fcts.hh index 27f05ddbdd..7a0de47a72 100644 --- a/cql3/functions/aggregate_fcts.hh +++ b/cql3/functions/aggregate_fcts.hh @@ -292,6 +292,27 @@ public: } }; +/// The same as `impl_max_function_for' but without knowledge of `Type'. +class impl_max_dynamic_function final : public aggregate_function::aggregate { + opt_bytes _max; +public: + virtual void reset() override { + _max = {}; + } + virtual opt_bytes compute(cql_serialization_format sf) override { + return _max.value_or(bytes{}); + } + virtual void add_input(cql_serialization_format sf, const std::vector& values) override { + if (!values[0]) { + return; + } + const auto val = *values[0]; + if (!_max || *_max < val) { + _max = val; + } + } +}; + template class max_function_for final : public native_aggregate_function { public: @@ -301,6 +322,14 @@ public: } }; +class max_dynamic_function final : public native_aggregate_function { +public: + max_dynamic_function(data_type io_type) : native_aggregate_function("max", io_type, { io_type }) {} + virtual std::unique_ptr new_aggregate() override { + return std::make_unique(); + } +}; + /** * Creates a MAX function for the specified type. * @@ -313,6 +342,12 @@ make_max_function() { return make_shared>(); } +/// The same as `make_max_function()' but with type provided in runtime. +inline shared_ptr +make_max_dynamic_function(data_type io_type) { + return make_shared(io_type); +} + template const Type& min_wrapper(const Type& t1, const Type& t2) { using std::min; @@ -353,6 +388,27 @@ public: } }; +/// The same as `impl_min_function_for' but without knowledge of `Type'. +class impl_min_dynamic_function final : public aggregate_function::aggregate { + opt_bytes _min; +public: + virtual void reset() override { + _min = {}; + } + virtual opt_bytes compute(cql_serialization_format sf) override { + return _min.value_or(bytes{}); + } + virtual void add_input(cql_serialization_format sf, const std::vector& values) override { + if (!values[0]) { + return; + } + const auto val = *values[0]; + if (!_min || val < *_min) { + _min = val; + } + } +}; + template class min_function_for final : public native_aggregate_function { public: @@ -362,6 +418,13 @@ public: } }; +class min_dynamic_function final : public native_aggregate_function { +public: + min_dynamic_function(data_type io_type) : native_aggregate_function("min", io_type, { io_type }) {} + virtual std::unique_ptr new_aggregate() override { + return std::make_unique(); + } +}; /** * Creates a MIN function for the specified type. @@ -375,6 +438,11 @@ make_min_function() { return make_shared>(); } +/// The same as `make_min_function()' but with type provided in runtime. +inline shared_ptr +make_min_dynamic_function(data_type io_type) { + return make_shared(io_type); +} template class impl_count_function_for final : public aggregate_function::aggregate { diff --git a/cql3/functions/functions.cc b/cql3/functions/functions.cc index bf4dafca28..5e3d2c7bdb 100644 --- a/cql3/functions/functions.cc +++ b/cql3/functions/functions.cc @@ -280,6 +280,8 @@ functions::get(database& db, static const function_name TOKEN_FUNCTION_NAME = function_name::native_function("token"); static const function_name TO_JSON_FUNCTION_NAME = function_name::native_function("tojson"); static const function_name FROM_JSON_FUNCTION_NAME = function_name::native_function("fromjson"); + static const function_name MIN_FUNCTION_NAME = function_name::native_function("min"); + static const function_name MAX_FUNCTION_NAME = function_name::native_function("max"); if (name.has_keyspace() ? name == TOKEN_FUNCTION_NAME @@ -312,6 +314,40 @@ functions::get(database& db, return make_from_json_function(db, keyspace, receiver->type); } + if (name.has_keyspace() + ? name == MIN_FUNCTION_NAME + : name.name == MIN_FUNCTION_NAME.name) { + if (provided_args.size() != 1) { + throw exceptions::invalid_request_exception("min() operates on 1 argument at a time"); + } + selection::selector *sp = dynamic_cast(provided_args[0].get()); + if (!sp) { + throw exceptions::invalid_request_exception("min() is only valid in SELECT clause"); + } + const data_type arg_type = sp->get_type(); + if (arg_type->is_collection() || arg_type->is_tuple() || arg_type->is_user_type()) { + // `min()' function is created on demand for arguments of compound types. + return aggregate_fcts::make_min_dynamic_function(arg_type); + } + } + + if (name.has_keyspace() + ? name == MAX_FUNCTION_NAME + : name.name == MAX_FUNCTION_NAME.name) { + if (provided_args.size() != 1) { + throw exceptions::invalid_request_exception("max() operates on 1 argument at a time"); + } + selection::selector *sp = dynamic_cast(provided_args[0].get()); + if (!sp) { + throw exceptions::invalid_request_exception("max() is only valid in SELECT clause"); + } + const data_type arg_type = sp->get_type(); + if (arg_type->is_collection() || arg_type->is_tuple() || arg_type->is_user_type()) { + // `max()' function is created on demand for arguments of compound types. + return aggregate_fcts::make_max_dynamic_function(arg_type); + } + } + std::vector> candidates; auto&& add_declared = [&] (function_name fn) { auto&& fns = _declared.equal_range(fn); From 75955beb0b4c5c6368fbc3968f1df6a28474334a Mon Sep 17 00:00:00 2001 From: Juliusz Stasiewicz Date: Thu, 5 Dec 2019 17:43:13 +0100 Subject: [PATCH 2/2] cql_query_tests: Added tests for min/max/count on collections This tests new min/max function for collections and tuples. CFs in test suite were named according to types being tested, e.g. `cf_map' what is not a valid CF name. Therefore, these names required "escaping" of invalid characters, here: simply replacing with '_'. --- tests/cql_query_test.cc | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/tests/cql_query_test.cc b/tests/cql_query_test.cc index f4498c938d..fa167d11f2 100644 --- a/tests/cql_query_test.cc +++ b/tests/cql_query_test.cc @@ -1471,7 +1471,12 @@ struct aggregate_function_test { std::vector _sorted_values; sstring table_name() { - return "cf_" + _column_type->cql3_type_name(); + sstring tbl_name = "cf_" + _column_type->cql3_type_name(); + // Substitute troublesome characters from `cql3_type_name()': + std::for_each(tbl_name.begin(), tbl_name.end(), [] (char& c) { + if (c == '<' || c == '>' || c == ',' || c == ' ') { c = '_'; } + }); + return tbl_name; } void call_function_and_expect(const char* fname, data_type type, data_value expected) { auto msg = _e.execute_cql(format("select {}(value) from {}", fname, table_name())).get0(); @@ -1632,6 +1637,34 @@ SEASTAR_TEST_CASE(test_aggregate_functions) { net::inet_address("1::1"), net::inet_address("1.0.0.1") ).test_min_max_count(); + + auto list_type_int = list_type_impl::get_instance(int32_type, false); + aggregate_function_test(e, list_type_int, + make_list_value(list_type_int, {1, 2, 3}), + make_list_value(list_type_int, {1, 2, 4}), + make_list_value(list_type_int, {2, 2, 3}) + ).test_min_max_count(); + + auto set_type_int = set_type_impl::get_instance(int32_type, false); + aggregate_function_test(e, set_type_int, + make_set_value(set_type_int, {1, 2, 3}), + make_set_value(set_type_int, {1, 2, 4}), + make_set_value(set_type_int, {2, 3, 4}) + ).test_min_max_count(); + + auto tuple_type_int_text = tuple_type_impl::get_instance({int32_type, utf8_type}); + aggregate_function_test(e, tuple_type_int_text, + make_tuple_value(tuple_type_int_text, {1, "aaa"}), + make_tuple_value(tuple_type_int_text, {1, "bbb"}), + make_tuple_value(tuple_type_int_text, {2, "aaa"}) + ).test_min_max_count(); + + auto map_type_int_text = map_type_impl::get_instance(int32_type, utf8_type, false); + aggregate_function_test(e, map_type_int_text, + make_map_value(map_type_int_text, {std::make_pair(data_value(1), data_value("asdf"))}), + make_map_value(map_type_int_text, {std::make_pair(data_value(2), data_value("asdf"))}), + make_map_value(map_type_int_text, {std::make_pair(data_value(2), data_value("bsdf"))}) + ).test_min_max_count(); }); }