vector_similarity_fcts: introduce similarity functions

This patch introduces scalar functions `similarity_cosine()`,
`similarity_euclidean()`, and `similarity_dot_product()`
which should return a float - similarity of the given vectors
calculated according to the function's similarity metric.

The argument types of this function are retrieved with
the `retrieve_vector_arg_types`, but shall be assignable to
`vector<float, N>` where `N` is the same for both arguments.

This patch introduces a dimensionality check during the execusion
of those functions.
This commit is contained in:
Dawid Pawlik
2025-12-19 15:15:06 +01:00
parent b72df3ae27
commit 5b2b8d596a
3 changed files with 49 additions and 0 deletions

View File

@@ -16,6 +16,7 @@
#include "cql3/functions/user_function.hh"
#include "cql3/functions/user_aggregate.hh"
#include "cql3/functions/uuid_fcts.hh"
#include "cql3/functions/vector_similarity_fcts.hh"
#include "data_dictionary/data_dictionary.hh"
#include "as_json_function.hh"
#include "cql3/prepare_context.hh"
@@ -398,6 +399,14 @@ functions::get(data_dictionary::database db,
}
});
const auto func_name = name.has_keyspace() ? name : name.as_native_function();
if (SIMILARITY_FUNCTIONS.contains(func_name)) {
auto arg_types = retrieve_vector_arg_types(func_name, provided_args);
auto fun = ::make_shared<vector_similarity_fct>(func_name.name, arg_types);
validate_types(db, keyspace, schema.get(), fun, provided_args, receiver_ks, receiver_cf);
return fun;
}
if (name.has_keyspace()
? name == TOKEN_FUNCTION_NAME
: name.name == TOKEN_FUNCTION_NAME.name) {

View File

@@ -82,6 +82,12 @@ float compute_dot_product_similarity(const std::vector<data_value>& v1, const st
} // namespace
thread_local const std::unordered_map<function_name, similarity_function_t> SIMILARITY_FUNCTIONS = {
{SIMILARITY_COSINE_FUNCTION_NAME, compute_cosine_similarity},
{SIMILARITY_EUCLIDEAN_FUNCTION_NAME, compute_euclidean_similarity},
{SIMILARITY_DOT_PRODUCT_FUNCTION_NAME, compute_dot_product_similarity},
};
std::vector<data_type> retrieve_vector_arg_types(const function_name& name, const std::vector<shared_ptr<assignment_testable>>& provided_args) {
if (provided_args.size() != 2) {
throw exceptions::invalid_request_exception(fmt::format("Invalid number of arguments for function {}(vector<float, n>, vector<float, n>)", name));
@@ -123,5 +129,22 @@ std::vector<data_type> retrieve_vector_arg_types(const function_name& name, cons
return {type, type};
}
bytes_opt vector_similarity_fct::execute(std::span<const bytes_opt> parameters) {
if (std::any_of(parameters.begin(), parameters.end(), [](const auto& param) {
return !param;
})) {
return std::nullopt;
}
const auto& type = arg_types()[0];
data_value v1 = type->deserialize(*parameters[0]);
data_value v2 = type->deserialize(*parameters[1]);
const auto& v1_elements = value_cast<std::vector<data_value>>(v1);
const auto& v2_elements = value_cast<std::vector<data_value>>(v2);
float result = SIMILARITY_FUNCTIONS.at(_name)(v1_elements, v2_elements);
return float_type->decompose(result);
}
} // namespace functions
} // namespace cql3

View File

@@ -8,13 +8,30 @@
#pragma once
#include "native_scalar_function.hh"
#include "cql3/assignment_testable.hh"
#include "cql3/functions/function_name.hh"
namespace cql3 {
namespace functions {
static const function_name SIMILARITY_COSINE_FUNCTION_NAME = function_name::native_function("similarity_cosine");
static const function_name SIMILARITY_EUCLIDEAN_FUNCTION_NAME = function_name::native_function("similarity_euclidean");
static const function_name SIMILARITY_DOT_PRODUCT_FUNCTION_NAME = function_name::native_function("similarity_dot_product");
using similarity_function_t = float (*)(const std::vector<data_value>&, const std::vector<data_value>&);
extern thread_local const std::unordered_map<function_name, similarity_function_t> SIMILARITY_FUNCTIONS;
std::vector<data_type> retrieve_vector_arg_types(const function_name& name, const std::vector<shared_ptr<assignment_testable>>& provided_args);
class vector_similarity_fct : public native_scalar_function {
public:
vector_similarity_fct(const sstring& name, const std::vector<data_type>& arg_types)
: native_scalar_function(name, float_type, arg_types) {
}
virtual bytes_opt execute(std::span<const bytes_opt> parameters) override;
};
} // namespace functions
} // namespace cql3