mirror of
https://github.com/scylladb/scylladb.git
synced 2026-05-12 19:02:12 +00:00
udf: fix dropping UDFs that share names with other UDFs used in UDAs
Currently, when dropping a function, we only check if there exist
an aggregate that uses a function with the same name as its state
function or final function. This may cause the drop to fail even
when it's just another UDF with the same name that's used in the
aggregate, even when the actual dropped function is not used there.
This patch fixes this by checking whether not only the name of the
UDA's sfunc and finalfunc, but also their argument types.
(cherry picked from commit 49077dd144)
This commit is contained in:
@@ -150,10 +150,11 @@ void functions::remove_function(const function_name& name, const std::vector<dat
|
||||
with_udf_iter(name, arg_types, [] (functions::declared_t::iterator i) { _declared.erase(i); });
|
||||
}
|
||||
|
||||
std::optional<function_name> functions::used_by_user_aggregate(const function_name& name) {
|
||||
std::optional<function_name> functions::used_by_user_aggregate(const function_name& name, const std::vector<data_type>& arg_types) {
|
||||
for (const shared_ptr<function>& fptr : _declared | boost::adaptors::map_values) {
|
||||
auto aggregate = dynamic_pointer_cast<user_aggregate>(fptr);
|
||||
if (aggregate && (aggregate->sfunc().name() == name || (aggregate->has_finalfunc() && aggregate->finalfunc().name() == name))) {
|
||||
if (aggregate && ((aggregate->sfunc().name() == name && aggregate->sfunc().arg_types() == arg_types)
|
||||
|| (aggregate->has_finalfunc() && aggregate->finalfunc().name() == name && aggregate->finalfunc().arg_types() == arg_types))) {
|
||||
return aggregate->name();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,7 +71,7 @@ public:
|
||||
static void add_function(shared_ptr<function>);
|
||||
static void replace_function(shared_ptr<function>);
|
||||
static void remove_function(const function_name& name, const std::vector<data_type>& arg_types);
|
||||
static std::optional<function_name> used_by_user_aggregate(const function_name& name);
|
||||
static std::optional<function_name> used_by_user_aggregate(const function_name& name, const std::vector<data_type>& arg_types);
|
||||
static std::optional<function_name> used_by_user_function(const ut_name& user_type);
|
||||
private:
|
||||
template <typename F>
|
||||
|
||||
@@ -35,7 +35,7 @@ drop_function_statement::prepare_schema_mutations(query_processor& qp, api::time
|
||||
if (!user_func) {
|
||||
throw exceptions::invalid_request_exception(format("'{}' is not a user defined function", func));
|
||||
}
|
||||
if (auto aggregate = functions::functions::used_by_user_aggregate(user_func->name()); bool(aggregate)) {
|
||||
if (auto aggregate = functions::functions::used_by_user_aggregate(user_func->name(), user_func->arg_types())) {
|
||||
throw exceptions::invalid_request_exception(format("Cannot delete function {}, as it is used by user-defined aggregate {}", func, *aggregate));
|
||||
}
|
||||
m = co_await qp.get_migration_manager().prepare_function_drop_announcement(user_func, ts);
|
||||
|
||||
@@ -71,13 +71,26 @@ def test_wrong_sfunc_or_ffunc(scylla_only, cql, test_keyspace):
|
||||
def test_drop_sfunc_or_ffunc(scylla_only, cql, test_keyspace):
|
||||
avg_partial_body = "(state tuple<bigint, bigint>, val bigint) CALLED ON NULL INPUT RETURNS tuple<bigint, bigint> LANGUAGE lua AS 'return {state[1] + val, state[2] + 1}'"
|
||||
div_body = "(state tuple<bigint, bigint>) CALLED ON NULL INPUT RETURNS bigint LANGUAGE lua AS 'return state[1]//state[2]'"
|
||||
with new_function(cql, test_keyspace, avg_partial_body) as avg_partial, new_function(cql, test_keyspace, div_body) as div_fun:
|
||||
with new_function(cql, test_keyspace, avg_partial_body, args="tuple<bigint, bigint>, bigint") as avg_partial,\
|
||||
new_function(cql, test_keyspace, div_body, args="tuple<bigint, bigint>") as div_fun:
|
||||
custom_avg_body = f"(bigint) SFUNC {avg_partial} STYPE tuple<bigint, bigint> FINALFUNC {div_fun} INITCOND (0,0)"
|
||||
with new_aggregate(cql, test_keyspace, custom_avg_body) as custom_avg:
|
||||
with pytest.raises(InvalidRequest, match="it is used"):
|
||||
cql.execute(f"DROP FUNCTION {test_keyspace}.{avg_partial}")
|
||||
with pytest.raises(InvalidRequest, match="it is used"):
|
||||
cql.execute(f"DROP FUNCTION {test_keyspace}.{div_fun}")
|
||||
avg_partial_body2 = "(state bigint, val bigint) CALLED ON NULL INPUT RETURNS bigint LANGUAGE lua AS 'return 42'"
|
||||
div_body2 = "(state bigint) CALLED ON NULL INPUT RETURNS bigint LANGUAGE lua AS 'return 42'"
|
||||
with new_function(cql, test_keyspace, avg_partial_body2, avg_partial, "bigint, bigint"),\
|
||||
new_function(cql, test_keyspace, div_body2, div_fun, "bigint"):
|
||||
with pytest.raises(InvalidRequest, match="There are multiple"):
|
||||
cql.execute(f"DROP FUNCTION {test_keyspace}.{avg_partial}")
|
||||
with pytest.raises(InvalidRequest, match="There are multiple"):
|
||||
cql.execute(f"DROP FUNCTION {test_keyspace}.{div_fun}")
|
||||
with pytest.raises(InvalidRequest, match="it is used"):
|
||||
cql.execute(f"DROP FUNCTION {test_keyspace}.{avg_partial}(tuple<bigint, bigint>, bigint)")
|
||||
with pytest.raises(InvalidRequest, match="it is used"):
|
||||
cql.execute(f"DROP FUNCTION {test_keyspace}.{div_fun}(tuple<bigint, bigint>)")
|
||||
|
||||
# Test that the state function takes a correct number of arguments - the state and the new input
|
||||
def test_incorrect_state_func(scylla_only, cql, test_keyspace):
|
||||
|
||||
Reference in New Issue
Block a user