diff --git a/cql3/functions/functions.cc b/cql3/functions/functions.cc index 7bf5c639ee..a626d777bb 100644 --- a/cql3/functions/functions.cc +++ b/cql3/functions/functions.cc @@ -150,10 +150,11 @@ void functions::remove_function(const function_name& name, const std::vector functions::used_by_user_aggregate(const function_name& name) { +std::optional functions::used_by_user_aggregate(const function_name& name, const std::vector& arg_types) { for (const shared_ptr& fptr : _declared | boost::adaptors::map_values) { auto aggregate = dynamic_pointer_cast(fptr); - if (aggregate && (aggregate->sfunc().name() == name || (aggregate->has_finalfunc() && aggregate->finalfunc().name() == name))) { + if (aggregate && ((aggregate->sfunc().name() == name && aggregate->sfunc().arg_types() == arg_types) + || (aggregate->has_finalfunc() && aggregate->finalfunc().name() == name && aggregate->finalfunc().arg_types() == arg_types))) { return aggregate->name(); } } diff --git a/cql3/functions/functions.hh b/cql3/functions/functions.hh index adab343998..6231ce4546 100644 --- a/cql3/functions/functions.hh +++ b/cql3/functions/functions.hh @@ -71,7 +71,7 @@ public: static void add_function(shared_ptr); static void replace_function(shared_ptr); static void remove_function(const function_name& name, const std::vector& arg_types); - static std::optional used_by_user_aggregate(const function_name& name); + static std::optional used_by_user_aggregate(const function_name& name, const std::vector& arg_types); static std::optional used_by_user_function(const ut_name& user_type); private: template diff --git a/cql3/statements/drop_function_statement.cc b/cql3/statements/drop_function_statement.cc index 05898c9cb1..98134efe6d 100644 --- a/cql3/statements/drop_function_statement.cc +++ b/cql3/statements/drop_function_statement.cc @@ -35,7 +35,7 @@ drop_function_statement::prepare_schema_mutations(query_processor& qp, api::time if (!user_func) { throw exceptions::invalid_request_exception(format("'{}' is not a user defined function", func)); } - if (auto aggregate = functions::functions::used_by_user_aggregate(user_func->name()); bool(aggregate)) { + if (auto aggregate = functions::functions::used_by_user_aggregate(user_func->name(), user_func->arg_types())) { throw exceptions::invalid_request_exception(format("Cannot delete function {}, as it is used by user-defined aggregate {}", func, *aggregate)); } m = co_await qp.get_migration_manager().prepare_function_drop_announcement(user_func, ts); diff --git a/test/cql-pytest/test_uda.py b/test/cql-pytest/test_uda.py index c405ed4162..310f44344f 100644 --- a/test/cql-pytest/test_uda.py +++ b/test/cql-pytest/test_uda.py @@ -71,13 +71,26 @@ def test_wrong_sfunc_or_ffunc(scylla_only, cql, test_keyspace): def test_drop_sfunc_or_ffunc(scylla_only, cql, test_keyspace): avg_partial_body = "(state tuple, val bigint) CALLED ON NULL INPUT RETURNS tuple LANGUAGE lua AS 'return {state[1] + val, state[2] + 1}'" div_body = "(state tuple) CALLED ON NULL INPUT RETURNS bigint LANGUAGE lua AS 'return state[1]//state[2]'" - with new_function(cql, test_keyspace, avg_partial_body) as avg_partial, new_function(cql, test_keyspace, div_body) as div_fun: + with new_function(cql, test_keyspace, avg_partial_body, args="tuple, bigint") as avg_partial,\ + new_function(cql, test_keyspace, div_body, args="tuple") as div_fun: custom_avg_body = f"(bigint) SFUNC {avg_partial} STYPE tuple FINALFUNC {div_fun} INITCOND (0,0)" with new_aggregate(cql, test_keyspace, custom_avg_body) as custom_avg: with pytest.raises(InvalidRequest, match="it is used"): cql.execute(f"DROP FUNCTION {test_keyspace}.{avg_partial}") with pytest.raises(InvalidRequest, match="it is used"): cql.execute(f"DROP FUNCTION {test_keyspace}.{div_fun}") + avg_partial_body2 = "(state bigint, val bigint) CALLED ON NULL INPUT RETURNS bigint LANGUAGE lua AS 'return 42'" + div_body2 = "(state bigint) CALLED ON NULL INPUT RETURNS bigint LANGUAGE lua AS 'return 42'" + with new_function(cql, test_keyspace, avg_partial_body2, avg_partial, "bigint, bigint"),\ + new_function(cql, test_keyspace, div_body2, div_fun, "bigint"): + with pytest.raises(InvalidRequest, match="There are multiple"): + cql.execute(f"DROP FUNCTION {test_keyspace}.{avg_partial}") + with pytest.raises(InvalidRequest, match="There are multiple"): + cql.execute(f"DROP FUNCTION {test_keyspace}.{div_fun}") + with pytest.raises(InvalidRequest, match="it is used"): + cql.execute(f"DROP FUNCTION {test_keyspace}.{avg_partial}(tuple, bigint)") + with pytest.raises(InvalidRequest, match="it is used"): + cql.execute(f"DROP FUNCTION {test_keyspace}.{div_fun}(tuple)") # Test that the state function takes a correct number of arguments - the state and the new input def test_incorrect_state_func(scylla_only, cql, test_keyspace):