From 5b2b8d596ab63ac8f27ea1a792e2eb015a564162 Mon Sep 17 00:00:00 2001 From: Dawid Pawlik Date: Fri, 19 Dec 2025 15:15:06 +0100 Subject: [PATCH] vector_similarity_fcts: introduce similarity functions This patch introduces scalar functions `similarity_cosine()`, `similarity_euclidean()`, and `similarity_dot_product()` which should return a float - the similarity of the given vectors calculated according to the function's similarity metric. The argument types of these functions are retrieved with `retrieve_vector_arg_types`, but shall be assignable to `vector<float, N>` where `N` is the same for both arguments. This patch introduces a dimensionality check during the execution of those functions. --- cql3/functions/functions.cc | 9 +++++++++ cql3/functions/vector_similarity_fcts.cc | 23 +++++++++++++++++++++++ cql3/functions/vector_similarity_fcts.hh | 17 +++++++++++++++++ 3 files changed, 49 insertions(+) diff --git a/cql3/functions/functions.cc b/cql3/functions/functions.cc index 55d6010616..63c9df90a5 100644 --- a/cql3/functions/functions.cc +++ b/cql3/functions/functions.cc @@ -16,6 +16,7 @@ #include "cql3/functions/user_function.hh" #include "cql3/functions/user_aggregate.hh" #include "cql3/functions/uuid_fcts.hh" +#include "cql3/functions/vector_similarity_fcts.hh" #include "data_dictionary/data_dictionary.hh" #include "as_json_function.hh" #include "cql3/prepare_context.hh" @@ -398,6 +399,14 @@ functions::get(data_dictionary::database db, } }); + const auto func_name = name.has_keyspace() ? name : name.as_native_function(); + if (SIMILARITY_FUNCTIONS.contains(func_name)) { + auto arg_types = retrieve_vector_arg_types(func_name, provided_args); + auto fun = ::make_shared(func_name.name, arg_types); + validate_types(db, keyspace, schema.get(), fun, provided_args, receiver_ks, receiver_cf); + return fun; + } + if (name.has_keyspace() ? 
name == TOKEN_FUNCTION_NAME : name.name == TOKEN_FUNCTION_NAME.name) { diff --git a/cql3/functions/vector_similarity_fcts.cc b/cql3/functions/vector_similarity_fcts.cc index cdbb51f59d..ed74d906e5 100644 --- a/cql3/functions/vector_similarity_fcts.cc +++ b/cql3/functions/vector_similarity_fcts.cc @@ -82,6 +82,12 @@ float compute_dot_product_similarity(const std::vector& v1, const st } // namespace +thread_local const std::unordered_map SIMILARITY_FUNCTIONS = { + {SIMILARITY_COSINE_FUNCTION_NAME, compute_cosine_similarity}, + {SIMILARITY_EUCLIDEAN_FUNCTION_NAME, compute_euclidean_similarity}, + {SIMILARITY_DOT_PRODUCT_FUNCTION_NAME, compute_dot_product_similarity}, +}; + std::vector retrieve_vector_arg_types(const function_name& name, const std::vector>& provided_args) { if (provided_args.size() != 2) { throw exceptions::invalid_request_exception(fmt::format("Invalid number of arguments for function {}(vector, vector)", name)); @@ -123,5 +129,22 @@ std::vector retrieve_vector_arg_types(const function_name& name, cons return {type, type}; } +bytes_opt vector_similarity_fct::execute(std::span parameters) { + if (std::any_of(parameters.begin(), parameters.end(), [](const auto& param) { + return !param; + })) { + return std::nullopt; + } + + const auto& type = arg_types()[0]; + data_value v1 = type->deserialize(*parameters[0]); + data_value v2 = type->deserialize(*parameters[1]); + const auto& v1_elements = value_cast>(v1); + const auto& v2_elements = value_cast>(v2); + + float result = SIMILARITY_FUNCTIONS.at(_name)(v1_elements, v2_elements); + return float_type->decompose(result); +} + } // namespace functions } // namespace cql3 diff --git a/cql3/functions/vector_similarity_fcts.hh b/cql3/functions/vector_similarity_fcts.hh index b529fe2b56..18c1d7e38c 100644 --- a/cql3/functions/vector_similarity_fcts.hh +++ b/cql3/functions/vector_similarity_fcts.hh @@ -8,13 +8,30 @@ #pragma once +#include "native_scalar_function.hh" #include "cql3/assignment_testable.hh" #include 
"cql3/functions/function_name.hh" namespace cql3 { namespace functions { +static const function_name SIMILARITY_COSINE_FUNCTION_NAME = function_name::native_function("similarity_cosine"); +static const function_name SIMILARITY_EUCLIDEAN_FUNCTION_NAME = function_name::native_function("similarity_euclidean"); +static const function_name SIMILARITY_DOT_PRODUCT_FUNCTION_NAME = function_name::native_function("similarity_dot_product"); + +using similarity_function_t = float (*)(const std::vector&, const std::vector&); +extern thread_local const std::unordered_map SIMILARITY_FUNCTIONS; + std::vector retrieve_vector_arg_types(const function_name& name, const std::vector>& provided_args); +class vector_similarity_fct : public native_scalar_function { +public: + vector_similarity_fct(const sstring& name, const std::vector& arg_types) + : native_scalar_function(name, float_type, arg_types) { + } + + virtual bytes_opt execute(std::span parameters) override; +}; + } // namespace functions } // namespace cql3