From 14d8cec130a99876aa417b30bb28e4402cfee423 Mon Sep 17 00:00:00 2001 From: Wojciech Mitros Date: Tue, 31 Jan 2023 12:03:11 +0100 Subject: [PATCH] udf: fix dropping UDFs that share names with other UDFs used in UDAs Currently, when dropping a function, we only check if there exist an aggregate that uses a function with the same name as its state function or final function. This may cause the drop to fail even when it's just another UDF with the same name that's used in the aggregate, even when the actual dropped function is not used there. This patch fixes this by checking whether not only the name of the UDA's sfunc and finalfunc, but also their argument types. (cherry picked from commit 49077dd1442218f0d8b3665ce80c35dc21e8c562) --- cql3/functions/functions.cc | 5 +++-- cql3/functions/functions.hh | 2 +- cql3/statements/drop_function_statement.cc | 2 +- test/cql-pytest/test_uda.py | 15 ++++++++++++++- 4 files changed, 19 insertions(+), 5 deletions(-) diff --git a/cql3/functions/functions.cc b/cql3/functions/functions.cc index 7bf5c639ee..a626d777bb 100644 --- a/cql3/functions/functions.cc +++ b/cql3/functions/functions.cc @@ -150,10 +150,11 @@ void functions::remove_function(const function_name& name, const std::vector functions::used_by_user_aggregate(const function_name& name) { +std::optional functions::used_by_user_aggregate(const function_name& name, const std::vector& arg_types) { for (const shared_ptr& fptr : _declared | boost::adaptors::map_values) { auto aggregate = dynamic_pointer_cast(fptr); - if (aggregate && (aggregate->sfunc().name() == name || (aggregate->has_finalfunc() && aggregate->finalfunc().name() == name))) { + if (aggregate && ((aggregate->sfunc().name() == name && aggregate->sfunc().arg_types() == arg_types) + || (aggregate->has_finalfunc() && aggregate->finalfunc().name() == name && aggregate->finalfunc().arg_types() == arg_types))) { return aggregate->name(); } } diff --git a/cql3/functions/functions.hh b/cql3/functions/functions.hh index adab343998..6231ce4546 100644 --- a/cql3/functions/functions.hh +++ b/cql3/functions/functions.hh @@ -71,7 +71,7 @@ public: static void add_function(shared_ptr); static void replace_function(shared_ptr); static void remove_function(const function_name& name, const std::vector& arg_types); - static std::optional used_by_user_aggregate(const function_name& name); + static std::optional used_by_user_aggregate(const function_name& name, const std::vector& arg_types); static std::optional used_by_user_function(const ut_name& user_type); private: template diff --git a/cql3/statements/drop_function_statement.cc b/cql3/statements/drop_function_statement.cc index 05898c9cb1..98134efe6d 100644 --- a/cql3/statements/drop_function_statement.cc +++ b/cql3/statements/drop_function_statement.cc @@ -35,7 +35,7 @@ drop_function_statement::prepare_schema_mutations(query_processor& qp, api::time if (!user_func) { throw exceptions::invalid_request_exception(format("'{}' is not a user defined function", func)); } - if (auto aggregate = functions::functions::used_by_user_aggregate(user_func->name()); bool(aggregate)) { + if (auto aggregate = functions::functions::used_by_user_aggregate(user_func->name(), user_func->arg_types())) { throw exceptions::invalid_request_exception(format("Cannot delete function {}, as it is used by user-defined aggregate {}", func, *aggregate)); } m = co_await qp.get_migration_manager().prepare_function_drop_announcement(user_func, ts); diff --git a/test/cql-pytest/test_uda.py b/test/cql-pytest/test_uda.py index c405ed4162..310f44344f 100644 --- a/test/cql-pytest/test_uda.py +++ b/test/cql-pytest/test_uda.py @@ -71,13 +71,26 @@ def test_wrong_sfunc_or_ffunc(scylla_only, cql, test_keyspace): def test_drop_sfunc_or_ffunc(scylla_only, cql, test_keyspace): avg_partial_body = "(state tuple, val bigint) CALLED ON NULL INPUT RETURNS tuple LANGUAGE lua AS 'return {state[1] + val, state[2] + 1}'" div_body = "(state tuple) CALLED ON NULL INPUT RETURNS bigint LANGUAGE lua AS 'return state[1]//state[2]'" - with new_function(cql, test_keyspace, avg_partial_body) as avg_partial, new_function(cql, test_keyspace, div_body) as div_fun: + with new_function(cql, test_keyspace, avg_partial_body, args="tuple, bigint") as avg_partial,\ + new_function(cql, test_keyspace, div_body, args="tuple") as div_fun: custom_avg_body = f"(bigint) SFUNC {avg_partial} STYPE tuple FINALFUNC {div_fun} INITCOND (0,0)" with new_aggregate(cql, test_keyspace, custom_avg_body) as custom_avg: with pytest.raises(InvalidRequest, match="it is used"): cql.execute(f"DROP FUNCTION {test_keyspace}.{avg_partial}") with pytest.raises(InvalidRequest, match="it is used"): cql.execute(f"DROP FUNCTION {test_keyspace}.{div_fun}") + avg_partial_body2 = "(state bigint, val bigint) CALLED ON NULL INPUT RETURNS bigint LANGUAGE lua AS 'return 42'" + div_body2 = "(state bigint) CALLED ON NULL INPUT RETURNS bigint LANGUAGE lua AS 'return 42'" + with new_function(cql, test_keyspace, avg_partial_body2, avg_partial, "bigint, bigint"),\ + new_function(cql, test_keyspace, div_body2, div_fun, "bigint"): + with pytest.raises(InvalidRequest, match="There are multiple"): + cql.execute(f"DROP FUNCTION {test_keyspace}.{avg_partial}") + with pytest.raises(InvalidRequest, match="There are multiple"): + cql.execute(f"DROP FUNCTION {test_keyspace}.{div_fun}") + with pytest.raises(InvalidRequest, match="it is used"): + cql.execute(f"DROP FUNCTION {test_keyspace}.{avg_partial}(tuple, bigint)") + with pytest.raises(InvalidRequest, match="it is used"): + cql.execute(f"DROP FUNCTION {test_keyspace}.{div_fun}(tuple)") # Test that the state function takes a correct number of arguments - the state and the new input def test_incorrect_state_func(scylla_only, cql, test_keyspace):