From aa17c7739ee6a255ca5b863b32d99b905900535b Mon Sep 17 00:00:00 2001 From: Szymon Malewski Date: Fri, 6 Mar 2026 02:08:57 +0100 Subject: [PATCH 1/3] vector_index: split query_base_table to return raw coordinator_result The inner query_base_table overloads previously called process_results() themselves, duplicating row_limit setup and making it impossible to thread per-execution context (e.g. a similarity provider) into result processing. Lift process_results() to the top-level overload and change the two inner overloads to return coordinator_result<foreign_ptr<lw_shared_ptr<query::result>>> directly. This cleanly separates query dispatch from result processing, and opens the door to passing execution-time context at the single process_results() call site. No functional change. --- cql3/statements/select_statement.cc | 43 +++++++++++++++-------------- cql3/statements/select_statement.hh | 4 +-- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/cql3/statements/select_statement.cc b/cql3/statements/select_statement.cc index c262395f64..3f379072ef 100644 --- a/cql3/statements/select_statement.cc +++ b/cql3/statements/select_statement.cc @@ -2171,7 +2171,7 @@ std::vector vector_indexed_table_select_statement::get_ann_ordering_vecto if (expr_value.is_null()) { throw exceptions::invalid_request_exception(fmt::format("Unsupported null value for column {}", _prepared_ann_ordering.first->name_as_text())); } - auto values = value_cast(ann_column->type->deserialize(expr::evaluate(ann_vector_expr, options).to_bytes())); + auto values = value_cast(ann_column->type->deserialize(std::move(expr_value).to_bytes())); return util::to_vector(values); } @@ -2180,6 +2180,19 @@ future<::shared_ptr> vector_indexed_tab lowres_clock::time_point timeout) const { auto command = prepare_command_for_base_query(qp, state, options, pkeys.size()); + auto result = co_await query_base_table(qp, state, options, command, timeout, pkeys); + + command->set_row_limit(get_limit(options, _limit)); + + co_return co_await 
wrap_result_to_error_message([this, command = std::move(command), &options](auto query_result) { + return process_results(std::move(query_result), command, options, _query_start_time_point); + })(std::move(result)); +} + +future>>> vector_indexed_table_select_statement::query_base_table(query_processor& qp, + service::query_state& state, const query_options& options, lw_shared_ptr command, lowres_clock::time_point timeout, + const std::vector& pkeys) const { + // For tables without clustering columns, we can optimize by querying // partition ranges instead of individual primary keys, since the // partition key alone uniquely identifies each row. @@ -2194,14 +2207,7 @@ future<::shared_ptr> vector_indexed_tab }; co_return co_await query_base_table(qp, state, options, std::move(command), timeout, to_partition_ranges(pkeys)); } - co_return co_await query_base_table(qp, state, options, std::move(command), timeout, pkeys); -} - -future<::shared_ptr> vector_indexed_table_select_statement::query_base_table(query_processor& qp, - service::query_state& state, const query_options& options, lw_shared_ptr command, lowres_clock::time_point timeout, - const std::vector& pkeys) const { - - coordinator_result>> result = co_await utils::result_map_reduce( + co_return co_await utils::result_map_reduce( pkeys.begin(), pkeys.end(), [&](this auto, auto& key) -> future>>> { auto cmd = ::make_lw_shared(*command); @@ -2215,25 +2221,20 @@ future<::shared_ptr> vector_indexed_tab co_return std::move(rqr.value().query_result); }, query::result_merger{command->get_row_limit(), query::max_partitions}); - - co_return co_await wrap_result_to_error_message([this, &command, &options](auto result) { - command->set_row_limit(get_limit(options, _limit)); - return process_results(std::move(result), command, options, _query_start_time_point); - })(std::move(result)); } -future<::shared_ptr> vector_indexed_table_select_statement::query_base_table(query_processor& qp, +future>>> 
vector_indexed_table_select_statement::query_base_table(query_processor& qp, service::query_state& state, const query_options& options, lw_shared_ptr command, lowres_clock::time_point timeout, std::vector partition_ranges) const { - co_return co_await qp.proxy() + coordinator_result rqr = co_await qp.proxy() .query_result(_query_schema, command, std::move(partition_ranges), options.get_consistency(), {timeout, state.get_permit(), state.get_client_state(), state.get_trace_state(), {}, {}, options.get_specific_options().node_local_only}, - std::nullopt) - .then(wrap_result_to_error_message([this, &options, command](service::storage_proxy::coordinator_query_result qr) { - command->set_row_limit(get_limit(options, _limit)); - return this->process_results(std::move(qr.query_result), command, options, _query_start_time_point); - })); + std::nullopt); + if (!rqr) { + co_return std::move(rqr).as_failure(); + } + co_return std::move(rqr.value().query_result); } namespace raw { diff --git a/cql3/statements/select_statement.hh b/cql3/statements/select_statement.hh index 87ce5accd9..262e2510eb 100644 --- a/cql3/statements/select_statement.hh +++ b/cql3/statements/select_statement.hh @@ -395,11 +395,11 @@ private: future<::shared_ptr> query_base_table(query_processor& qp, service::query_state& state, const query_options& options, const std::vector& pkeys, lowres_clock::time_point timeout) const; - future<::shared_ptr> query_base_table(query_processor& qp, service::query_state& state, + future>>> query_base_table(query_processor& qp, service::query_state& state, const query_options& options, lw_shared_ptr command, lowres_clock::time_point timeout, const std::vector& pkeys) const; - future<::shared_ptr> query_base_table(query_processor& qp, service::query_state& state, + future>>> query_base_table(query_processor& qp, service::query_state& state, const query_options& options, lw_shared_ptr command, lowres_clock::time_point timeout, std::vector partition_ranges) const; }; From 
5e94abe3bc7d56b98c1d2d85172f82eabced84e9 Mon Sep 17 00:00:00 2001 From: Szymon Malewski Date: Wed, 15 Apr 2026 18:22:43 +0200 Subject: [PATCH 2/3] cql3: extract vector_indexed_table_select_statement into own compilation unit Move vector_indexed_table_select_statement and its associated helpers (ann_ordering_info, get_ann_ordering_info, add_similarity_function_to_selectors, get_similarity_ordering_comparator) from select_statement.hh/.cc into new files cql3/statements/external_search/vector_indexed_table_select_statement.hh/.cc. --- configure.py | 1 + cql3/CMakeLists.txt | 1 + .../vector_indexed_table_select_statement.cc | 363 ++++++++++++++++++ .../vector_indexed_table_select_statement.hh | 90 +++++ cql3/statements/select_statement.cc | 280 +------------- cql3/statements/select_statement.hh | 45 --- 6 files changed, 456 insertions(+), 324 deletions(-) create mode 100644 cql3/statements/external_search/vector_indexed_table_select_statement.cc create mode 100644 cql3/statements/external_search/vector_indexed_table_select_statement.hh diff --git a/configure.py b/configure.py index fd3951c0bb..7c4145b8f9 100755 --- a/configure.py +++ b/configure.py @@ -1061,6 +1061,7 @@ scylla_core = (['message/messaging_service.cc', 'cql3/statements/prune_materialized_view_statement.cc', 'cql3/statements/batch_statement.cc', 'cql3/statements/select_statement.cc', + 'cql3/statements/external_search/vector_indexed_table_select_statement.cc', 'cql3/statements/use_statement.cc', 'cql3/statements/index_prop_defs.cc', 'cql3/statements/index_target.cc', diff --git a/cql3/CMakeLists.txt b/cql3/CMakeLists.txt index 20bf0bbe21..3d4c3f7c4e 100644 --- a/cql3/CMakeLists.txt +++ b/cql3/CMakeLists.txt @@ -81,6 +81,7 @@ target_sources(cql3 statements/prune_materialized_view_statement.cc statements/batch_statement.cc statements/select_statement.cc + statements/external_search/vector_indexed_table_select_statement.cc statements/use_statement.cc statements/index_prop_defs.cc statements/index_target.cc 
diff --git a/cql3/statements/external_search/vector_indexed_table_select_statement.cc b/cql3/statements/external_search/vector_indexed_table_select_statement.cc new file mode 100644 index 0000000000..469d653897 --- /dev/null +++ b/cql3/statements/external_search/vector_indexed_table_select_statement.cc @@ -0,0 +1,363 @@ +/* + * Copyright (C) 2025-present ScyllaDB + */ + +/* + * SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.1 + */ + +#include "cql3/statements/external_search/vector_indexed_table_select_statement.hh" + +#include "cql3/expr/evaluate.hh" +#include "cql3/expr/expr-utils.hh" +#include "cql3/functions/functions.hh" +#include "cql3/statements/raw/select_statement.hh" +#include "cql3/query_processor.hh" +#include "cql3/util.hh" + +#include "db/consistency_level_validations.hh" +#include "replica/database.hh" +#include "exceptions/exceptions.hh" +#include "index/vector_index.hh" +#include "query/query_result_merger.hh" +#include "service/storage_proxy.hh" +#include "types/vector.hh" +#include "utils/result_loop.hh" + +#include +#include +#include + + +template +using coordinator_result = cql3::statements::select_statement::coordinator_result; + +namespace cql3 { + +namespace statements { + +static logging::logger logger("vector_indexed_table_select_statement"); + +namespace { + +template +auto measure_index_latency(const schema& schema, const secondary_index::index& index, Func&& func) -> std::invoke_result_t { + auto start_time = lowres_system_clock::now(); + auto result = co_await func(); + auto duration = lowres_system_clock::now() - start_time; + + auto stats = schema.table().get_index_manager().get_index_stats(index.metadata().name()); + if (stats) { + stats->add_latency(duration); + } + + co_return result; +} + +template +struct result_to_error_message_wrapper { + C c; + + template + auto operator()(coordinator_result&& arg) { + if constexpr (std::is_void_v) { + if (arg) { + return futurize_invoke(c); + } else { + return 
make_ready_future>::value_type>( + ::make_shared(std::move(arg).assume_error()) + ); + } + } else { + if (arg) { + return futurize_invoke(c, std::move(arg).value()); + } else { + return make_ready_future>::value_type>( + ::make_shared(std::move(arg).assume_error()) + ); + } + } + } +}; + +template +auto wrap_result_to_error_message(C&& c) { + return result_to_error_message_wrapper{std::move(c)}; +} + +} // anonymous namespace + +std::optional get_ann_ordering_info( + data_dictionary::database db, + schema_ptr schema, + lw_shared_ptr parameters, + prepare_context& ctx) { + + if (parameters->orderings().empty()) { + return std::nullopt; + } + + auto [column_id, ordering] = parameters->orderings().front(); + const auto& ann_vector = std::get_if(&ordering); + if (!ann_vector) { + return std::nullopt; + } + + ::shared_ptr column = column_id->prepare_column_identifier(*schema); + const column_definition* def = schema->get_column_definition(column->name()); + if (!def) { + throw exceptions::invalid_request_exception( + fmt::format("Undefined column name {}", column->text())); + } + + if (!def->type->is_vector() || static_cast(def->type.get())->get_elements_type()->get_kind() != abstract_type::kind::float_kind) { + throw exceptions::invalid_request_exception("ANN ordering is only supported on float vector indexes"); + } + + auto e = expr::prepare_expression(*ann_vector, db, schema->ks_name(), nullptr, def->column_specification); + expr::fill_prepare_context(e, ctx); + + raw::select_statement::prepared_ann_ordering_type prepared_ann_ordering = std::make_pair(std::move(def), std::move(e)); + + auto cf = db.find_column_family(schema); + auto& sim = cf.get_index_manager(); + + auto indexes = sim.list_indexes(); + auto it = std::find_if(indexes.begin(), indexes.end(), [&prepared_ann_ordering](const auto& ind) { + return secondary_index::vector_index::is_vector_index_on_column(ind.metadata(), prepared_ann_ordering.first->name_as_text()); + }); + + if (it == indexes.end()) { + 
throw exceptions::invalid_request_exception("ANN ordering by vector requires the column to be indexed using 'vector_index'"); + } + + return ann_ordering_info{ + *it, + std::move(prepared_ann_ordering), + secondary_index::vector_index::is_rescoring_enabled(it->metadata().options()) + }; +} + +uint32_t add_similarity_function_to_selectors( + std::vector& prepared_selectors, + const ann_ordering_info& ann_ordering_info, + data_dictionary::database db, + schema_ptr schema) { + auto similarity_function_name = secondary_index::vector_index::get_cql_similarity_function_name(ann_ordering_info._index.metadata().options()); + // Create the function name + auto func_name = functions::function_name::native_function(sstring(similarity_function_name)); + + // Create the function arguments + std::vector args; + args.push_back(expr::column_value(ann_ordering_info._prepared_ann_ordering.first)); + args.push_back(ann_ordering_info._prepared_ann_ordering.second); + + // Get the function object + std::vector> provided_args; + provided_args.push_back(expr::as_assignment_testable(args[0], expr::type_of(args[0]))); + provided_args.push_back(expr::as_assignment_testable(args[1], expr::type_of(args[1]))); + + auto func = cql3::functions::instance().get(db, schema->ks_name(), func_name, provided_args, schema->ks_name(), schema->cf_name(), nullptr); + + // Create the function call expression + expr::function_call similarity_func_call{ + .func = func, + .args = std::move(args), + }; + + // Add the similarity function as a prepared selector (last) + prepared_selectors.push_back(selection::prepared_selector{ + .expr = std::move(similarity_func_call), + .alias = nullptr, + }); + return prepared_selectors.size() - 1; +} + +select_statement::ordering_comparator_type get_similarity_ordering_comparator(std::vector& prepared_selectors, uint32_t similarity_column_index) { + auto type = expr::type_of(prepared_selectors[similarity_column_index].expr); + if (type->get_kind() != 
abstract_type::kind::float_kind) { + seastar::on_internal_error(logger, "Similarity function must return float type."); + } + return [similarity_column_index, type] (const raw::select_statement::result_row_type& r1, const raw::select_statement::result_row_type& r2) { + auto& c1 = r1[similarity_column_index]; + auto& c2 = r2[similarity_column_index]; + auto f1 = c1 ? value_cast(type->deserialize(*c1)) : std::numeric_limits::quiet_NaN(); + auto f2 = c2 ? value_cast(type->deserialize(*c2)) : std::numeric_limits::quiet_NaN(); + if (std::isfinite(f1) && std::isfinite(f2)) { + return f1 > f2; + } + return std::isfinite(f1); + }; +} + +::shared_ptr vector_indexed_table_select_statement::prepare(data_dictionary::database db, schema_ptr schema, + uint32_t bound_terms, lw_shared_ptr parameters, ::shared_ptr selection, + ::shared_ptr restrictions, ::shared_ptr> group_by_cell_indices, bool is_reversed, + ordering_comparator_type ordering_comparator, prepared_ann_ordering_type prepared_ann_ordering, std::optional limit, + std::optional per_partition_limit, cql_stats& stats, const secondary_index::index& index, std::unique_ptr attrs) { + + auto prepared_filter = vector_search::prepare_filter(*restrictions, parameters->allow_filtering()); + + return ::make_shared(schema, bound_terms, parameters, std::move(selection), std::move(restrictions), + std::move(group_by_cell_indices), is_reversed, std::move(ordering_comparator), std::move(prepared_ann_ordering), std::move(limit), + std::move(per_partition_limit), stats, index, std::move(prepared_filter), std::move(attrs)); +} + +vector_indexed_table_select_statement::vector_indexed_table_select_statement(schema_ptr schema, uint32_t bound_terms, lw_shared_ptr parameters, + ::shared_ptr selection, ::shared_ptr restrictions, + ::shared_ptr> group_by_cell_indices, bool is_reversed, ordering_comparator_type ordering_comparator, + prepared_ann_ordering_type prepared_ann_ordering, std::optional limit, + std::optional per_partition_limit, 
cql_stats& stats, const secondary_index::index& index, + vector_search::prepared_filter prepared_filter, std::unique_ptr attrs) + : select_statement{schema, bound_terms, parameters, selection, restrictions, group_by_cell_indices, is_reversed, ordering_comparator, limit, + per_partition_limit, stats, std::move(attrs)} + , _index{index} + , _prepared_ann_ordering(std::move(prepared_ann_ordering)) + , _prepared_filter(std::move(prepared_filter)) { + + if (!limit.has_value()) { + throw exceptions::invalid_request_exception("Vector ANN queries must have a limit specified"); + } + + if (per_partition_limit.has_value()) { + throw exceptions::invalid_request_exception("Vector ANN queries do not support per-partition limits"); + } + + if (selection->is_aggregate()) { + throw exceptions::invalid_request_exception("Vector ANN queries cannot be run with aggregation"); + } +} + +future> vector_indexed_table_select_statement::do_execute( + query_processor& qp, service::query_state& state, const query_options& options) const { + + auto limit = get_limit(options, _limit); + + auto result = co_await measure_index_latency(*_schema, _index, [this, &qp, &state, &options, &limit](this auto) -> future> { + tracing::add_table_name(state.get_trace_state(), keyspace(), column_family()); + validate_for_read(options.get_consistency()); + + _query_start_time_point = gc_clock::now(); + + update_stats(); + + if (limit > max_ann_query_limit) { + co_await coroutine::return_exception(exceptions::invalid_request_exception( + fmt::format("Use of ANN OF in an ORDER BY clause requires a LIMIT that is not greater than {}. 
LIMIT was {}", max_ann_query_limit, limit))); + } + + auto timeout = db::timeout_clock::now() + get_timeout(state.get_client_state(), options); + auto aoe = abort_on_expiry(timeout); + auto filter_json = _prepared_filter.to_json(options); + uint64_t fetch = static_cast(std::ceil(limit * secondary_index::vector_index::get_oversampling(_index.metadata().options()))); + auto pkeys = co_await qp.vector_store_client().ann( + _schema->ks_name(), _index.metadata().name(), _schema, get_ann_ordering_vector(options), fetch, filter_json, aoe.abort_source()); + if (!pkeys.has_value()) { + co_await coroutine::return_exception( + exceptions::invalid_request_exception(std::visit(vector_search::vector_store_client::ann_error_visitor{}, pkeys.error()))); + } + + if (pkeys->size() > limit && !secondary_index::vector_index::is_rescoring_enabled(_index.metadata().options())) { + pkeys->erase(pkeys->begin() + limit, pkeys->end()); + } + + co_return co_await query_base_table(qp, state, options, pkeys.value(), timeout); + }); + + auto page_size = options.get_page_size(); + if (page_size > 0 && (uint64_t) page_size < limit) { + result->add_warning("Paging is not supported for Vector Search queries. 
The entire result set has been returned."); + } + co_return result; +} + +void vector_indexed_table_select_statement::update_stats() const { + ++_stats.secondary_index_reads; + ++_stats.query_cnt(source_selector::USER, _ks_sel, cond_selector::NO_CONDITIONS, statement_type::SELECT); +} + +lw_shared_ptr vector_indexed_table_select_statement::prepare_command_for_base_query( + query_processor& qp, service::query_state& state, const query_options& options, uint64_t fetch_limit) const { + auto slice = make_partition_slice(options); + return ::make_lw_shared(_schema->id(), _schema->version(), std::move(slice), qp.proxy().get_max_result_size(slice), + query::tombstone_limit(qp.proxy().get_tombstone_limit()), + query::row_limit(get_inner_loop_limit(fetch_limit, _selection->is_aggregate())), query::partition_limit(query::max_partitions), + _query_start_time_point, tracing::make_trace_info(state.get_trace_state()), query_id::create_null_id(), query::is_first_page::no, + options.get_timestamp(state)); +} + +std::vector vector_indexed_table_select_statement::get_ann_ordering_vector(const query_options& options) const { + auto [ann_column, ann_vector_expr] = _prepared_ann_ordering; + auto expr_value = expr::evaluate(ann_vector_expr, options); + if (expr_value.is_null()) { + throw exceptions::invalid_request_exception(fmt::format("Unsupported null value for column {}", _prepared_ann_ordering.first->name_as_text())); + } + auto values = value_cast(ann_column->type->deserialize(std::move(expr_value).to_bytes())); + return util::to_vector(values); +} + +future<::shared_ptr> vector_indexed_table_select_statement::query_base_table(query_processor& qp, + service::query_state& state, const query_options& options, const std::vector& pkeys, + lowres_clock::time_point timeout) const { + auto command = prepare_command_for_base_query(qp, state, options, pkeys.size()); + + auto result = co_await query_base_table(qp, state, options, command, timeout, pkeys); + + 
command->set_row_limit(get_limit(options, _limit)); + + co_return co_await wrap_result_to_error_message([this, command = std::move(command), &options](auto query_result) { + return process_results(std::move(query_result), command, options, _query_start_time_point); + })(std::move(result)); +} + +future>>> vector_indexed_table_select_statement::query_base_table(query_processor& qp, + service::query_state& state, const query_options& options, lw_shared_ptr command, lowres_clock::time_point timeout, + const std::vector& pkeys) const { + + // For tables without clustering columns, we can optimize by querying + // partition ranges instead of individual primary keys, since the + // partition key alone uniquely identifies each row. + if (_schema->clustering_key_size() == 0) { + auto to_partition_ranges = [](const std::vector& pkeys) -> std::vector { + std::vector partition_ranges; + std::ranges::transform(pkeys, std::back_inserter(partition_ranges), [](const auto& pkey) { + return dht::partition_range::make_singular(pkey.partition); + }); + + return partition_ranges; + }; + co_return co_await query_base_table(qp, state, options, std::move(command), timeout, to_partition_ranges(pkeys)); + } + co_return co_await utils::result_map_reduce( + pkeys.begin(), pkeys.end(), + [&](this auto, auto& key) -> future>>> { + auto cmd = ::make_lw_shared(*command); + cmd->slice._row_ranges = query::clustering_row_ranges{query::clustering_range::make_singular(key.clustering)}; + coordinator_result rqr = + co_await qp.proxy().query_result(_schema, cmd, {dht::partition_range::make_singular(key.partition)}, options.get_consistency(), + {timeout, state.get_permit(), state.get_client_state(), state.get_trace_state()}); + if (!rqr) { + co_return std::move(rqr).as_failure(); + } + co_return std::move(rqr.value().query_result); + }, + query::result_merger{command->get_row_limit(), query::max_partitions}); +} + +future>>> vector_indexed_table_select_statement::query_base_table(query_processor& qp, + 
service::query_state& state, const query_options& options, lw_shared_ptr command, lowres_clock::time_point timeout, + std::vector partition_ranges) const { + + coordinator_result rqr = co_await qp.proxy() + .query_result(_query_schema, command, std::move(partition_ranges), options.get_consistency(), + {timeout, state.get_permit(), state.get_client_state(), state.get_trace_state(), {}, {}, options.get_specific_options().node_local_only}, + std::nullopt); + if (!rqr) { + co_return std::move(rqr).as_failure(); + } + co_return std::move(rqr.value().query_result); +} + +} // namespace statements + +} // namespace cql3 diff --git a/cql3/statements/external_search/vector_indexed_table_select_statement.hh b/cql3/statements/external_search/vector_indexed_table_select_statement.hh new file mode 100644 index 0000000000..cd6239f7d0 --- /dev/null +++ b/cql3/statements/external_search/vector_indexed_table_select_statement.hh @@ -0,0 +1,90 @@ +/* + * Copyright (C) 2025-present ScyllaDB + */ + +/* + * SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.1 + */ + +#pragma once + +#include "cql3/statements/select_statement.hh" +#include "vector_search/vector_store_client.hh" +#include "vector_search/filter.hh" + +#include + +namespace cql3::statements { + +/// ANN ordering metadata resolved during prepare. +struct ann_ordering_info { + secondary_index::index _index; + raw::select_statement::prepared_ann_ordering_type _prepared_ann_ordering; + bool is_rescoring_enabled; +}; + +/// Resolves ANN ordering metadata from the query's ORDER BY clause. +/// Returns std::nullopt if the query is not an ANN query. +std::optional get_ann_ordering_info( + data_dictionary::database db, + schema_ptr schema, + lw_shared_ptr parameters, + prepare_context& ctx); + +/// Adds a similarity function call to prepared_selectors based on the ANN index. +/// Returns the index of the appended selector within prepared_selectors. 
+uint32_t add_similarity_function_to_selectors( + std::vector& prepared_selectors, + const ann_ordering_info& ann_ordering_info, + data_dictionary::database db, + schema_ptr schema); + +/// Builds an ordering comparator that sorts by descending similarity score. +select_statement::ordering_comparator_type get_similarity_ordering_comparator( + std::vector& prepared_selectors, + uint32_t similarity_column_index); + +class vector_indexed_table_select_statement : public select_statement { + secondary_index::index _index; + prepared_ann_ordering_type _prepared_ann_ordering; + vector_search::prepared_filter _prepared_filter; + mutable gc_clock::time_point _query_start_time_point; + +public: + static constexpr size_t max_ann_query_limit = 1000; + + static ::shared_ptr prepare(data_dictionary::database db, schema_ptr schema, uint32_t bound_terms, + lw_shared_ptr parameters, ::shared_ptr selection, + ::shared_ptr restrictions, ::shared_ptr> group_by_cell_indices, bool is_reversed, + ordering_comparator_type ordering_comparator, prepared_ann_ordering_type prepared_ann_ordering, std::optional limit, + std::optional per_partition_limit, cql_stats& stats, const secondary_index::index& index, std::unique_ptr attrs); + + vector_indexed_table_select_statement(schema_ptr schema, uint32_t bound_terms, lw_shared_ptr parameters, + ::shared_ptr selection, ::shared_ptr restrictions, + ::shared_ptr> group_by_cell_indices, bool is_reversed, ordering_comparator_type ordering_comparator, + prepared_ann_ordering_type prepared_ann_ordering, std::optional limit, std::optional per_partition_limit, + cql_stats& stats, const secondary_index::index& index, vector_search::prepared_filter prepared_filter, std::unique_ptr attrs); + +private: + future<::shared_ptr> do_execute( + query_processor& qp, service::query_state& state, const query_options& options) const override; + + void update_stats() const; + + lw_shared_ptr prepare_command_for_base_query(query_processor& qp, service::query_state& state, 
const query_options& options, uint64_t fetch_limit) const; + + std::vector get_ann_ordering_vector(const query_options& options) const; + + future<::shared_ptr> query_base_table(query_processor& qp, service::query_state& state, + const query_options& options, const std::vector& pkeys, lowres_clock::time_point timeout) const; + + future>>> query_base_table(query_processor& qp, service::query_state& state, + const query_options& options, lw_shared_ptr command, lowres_clock::time_point timeout, + const std::vector& pkeys) const; + + future>>> query_base_table(query_processor& qp, service::query_state& state, + const query_options& options, lw_shared_ptr command, lowres_clock::time_point timeout, + std::vector partition_ranges) const; +}; + +} // namespace cql3::statements diff --git a/cql3/statements/select_statement.cc b/cql3/statements/select_statement.cc index 3f379072ef..2968ce191c 100644 --- a/cql3/statements/select_statement.cc +++ b/cql3/statements/select_statement.cc @@ -11,6 +11,7 @@ #include "cql3/statements/strong_consistency/select_statement.hh" #include "cql3/statements/strong_consistency/statement_helpers.hh" #include "cql3/statements/select_statement.hh" +#include "cql3/statements/external_search/vector_indexed_table_select_statement.hh" #include "cql3/expr/expression.hh" #include "cql3/expr/evaluate.hh" #include "cql3/expr/expr-utils.hh" @@ -33,8 +34,6 @@ #include "cql3/util.hh" #include "cql3/restrictions/statement_restrictions.hh" #include "index/secondary_index.hh" -#include "types/vector.hh" -#include "vector_search/filter.hh" #include "validation.hh" #include "exceptions/unrecognized_entity_exception.hh" #include @@ -1960,283 +1959,6 @@ mutation_fragments_select_statement::do_execute(query_processor& qp, service::qu })); } -struct ann_ordering_info { - secondary_index::index _index; - raw::select_statement::prepared_ann_ordering_type _prepared_ann_ordering; - bool is_rescoring_enabled; -}; - -static std::optional get_ann_ordering_info( - 
data_dictionary::database db, - schema_ptr schema, - lw_shared_ptr parameters, - prepare_context& ctx) { - - if (parameters->orderings().empty()) { - return std::nullopt; - } - - auto [column_id, ordering] = parameters->orderings().front(); - const auto& ann_vector = std::get_if(&ordering); - if (!ann_vector) { - return std::nullopt; - } - - ::shared_ptr column = column_id->prepare_column_identifier(*schema); - const column_definition* def = schema->get_column_definition(column->name()); - if (!def) { - throw exceptions::invalid_request_exception( - fmt::format("Undefined column name {}", column->text())); - } - - if (!def->type->is_vector() || static_cast(def->type.get())->get_elements_type()->get_kind() != abstract_type::kind::float_kind) { - throw exceptions::invalid_request_exception("ANN ordering is only supported on float vector indexes"); - } - - auto e = expr::prepare_expression(*ann_vector, db, schema->ks_name(), nullptr, def->column_specification); - expr::fill_prepare_context(e, ctx); - - raw::select_statement::prepared_ann_ordering_type prepared_ann_ordering = std::make_pair(std::move(def), std::move(e)); - - auto cf = db.find_column_family(schema); - auto& sim = cf.get_index_manager(); - - auto indexes = sim.list_indexes(); - auto it = std::find_if(indexes.begin(), indexes.end(), [&prepared_ann_ordering](const auto& ind) { - return secondary_index::vector_index::is_vector_index_on_column(ind.metadata(), prepared_ann_ordering.first->name_as_text()); - }); - - if (it == indexes.end()) { - throw exceptions::invalid_request_exception("ANN ordering by vector requires the column to be indexed using 'vector_index'"); - } - - return ann_ordering_info{ - *it, - std::move(prepared_ann_ordering), - secondary_index::vector_index::is_rescoring_enabled(it->metadata().options()) - }; -} - -static uint32_t add_similarity_function_to_selectors( - std::vector& prepared_selectors, - const ann_ordering_info& ann_ordering_info, - data_dictionary::database db, - schema_ptr 
schema) { - auto similarity_function_name = secondary_index::vector_index::get_cql_similarity_function_name(ann_ordering_info._index.metadata().options()); - // Create the function name - auto func_name = functions::function_name::native_function(sstring(similarity_function_name)); - - // Create the function arguments - std::vector args; - args.push_back(expr::column_value(ann_ordering_info._prepared_ann_ordering.first)); - args.push_back(ann_ordering_info._prepared_ann_ordering.second); - - // Get the function object - std::vector> provided_args; - provided_args.push_back(expr::as_assignment_testable(args[0], expr::type_of(args[0]))); - provided_args.push_back(expr::as_assignment_testable(args[1], expr::type_of(args[1]))); - - auto func = cql3::functions::instance().get(db, schema->ks_name(), func_name, provided_args, schema->ks_name(), schema->cf_name(), nullptr); - - // Create the function call expression - expr::function_call similarity_func_call{ - .func = func, - .args = std::move(args), - }; - - // Add the similarity function as a prepared selector (last) - prepared_selectors.push_back(selection::prepared_selector{ - .expr = std::move(similarity_func_call), - .alias = nullptr, - }); - return prepared_selectors.size() - 1; -} - -static select_statement::ordering_comparator_type get_similarity_ordering_comparator(std::vector& prepared_selectors, uint32_t similarity_column_index) { - auto type = expr::type_of(prepared_selectors[similarity_column_index].expr); - if (type->get_kind() != abstract_type::kind::float_kind) { - seastar::on_internal_error(logger, "Similarity function must return float type."); - } - return [similarity_column_index, type] (const raw::select_statement::result_row_type& r1, const raw::select_statement::result_row_type& r2) { - auto& c1 = r1[similarity_column_index]; - auto& c2 = r2[similarity_column_index]; - auto f1 = c1 ? value_cast(type->deserialize(*c1)) : std::numeric_limits::quiet_NaN(); - auto f2 = c2 ? 
value_cast(type->deserialize(*c2)) : std::numeric_limits::quiet_NaN(); - if (std::isfinite(f1) && std::isfinite(f2)) { - return f1 > f2; - } - return std::isfinite(f1); - }; -} - -::shared_ptr vector_indexed_table_select_statement::prepare(data_dictionary::database db, schema_ptr schema, - uint32_t bound_terms, lw_shared_ptr parameters, ::shared_ptr selection, - ::shared_ptr restrictions, ::shared_ptr> group_by_cell_indices, bool is_reversed, - ordering_comparator_type ordering_comparator, prepared_ann_ordering_type prepared_ann_ordering, std::optional limit, - std::optional per_partition_limit, cql_stats& stats, const secondary_index::index& index, std::unique_ptr attrs) { - - auto prepared_filter = vector_search::prepare_filter(*restrictions, parameters->allow_filtering()); - - return ::make_shared(schema, bound_terms, parameters, std::move(selection), std::move(restrictions), - std::move(group_by_cell_indices), is_reversed, std::move(ordering_comparator), std::move(prepared_ann_ordering), std::move(limit), - std::move(per_partition_limit), stats, index, std::move(prepared_filter), std::move(attrs)); -} - -vector_indexed_table_select_statement::vector_indexed_table_select_statement(schema_ptr schema, uint32_t bound_terms, lw_shared_ptr parameters, - ::shared_ptr selection, ::shared_ptr restrictions, - ::shared_ptr> group_by_cell_indices, bool is_reversed, ordering_comparator_type ordering_comparator, - prepared_ann_ordering_type prepared_ann_ordering, std::optional limit, - std::optional per_partition_limit, cql_stats& stats, const secondary_index::index& index, - vector_search::prepared_filter prepared_filter, std::unique_ptr attrs) - : select_statement{schema, bound_terms, parameters, selection, restrictions, group_by_cell_indices, is_reversed, ordering_comparator, limit, - per_partition_limit, stats, std::move(attrs)} - , _index{index} - , _prepared_ann_ordering(std::move(prepared_ann_ordering)) - , _prepared_filter(std::move(prepared_filter)) { - - if 
(!limit.has_value()) { - throw exceptions::invalid_request_exception("Vector ANN queries must have a limit specified"); - } - - if (per_partition_limit.has_value()) { - throw exceptions::invalid_request_exception("Vector ANN queries do not support per-partition limits"); - } - - if (selection->is_aggregate()) { - throw exceptions::invalid_request_exception("Vector ANN queries cannot be run with aggregation"); - } -} - -future> vector_indexed_table_select_statement::do_execute( - query_processor& qp, service::query_state& state, const query_options& options) const { - - auto limit = get_limit(options, _limit); - - auto result = co_await measure_index_latency(*_schema, _index, [this, &qp, &state, &options, &limit](this auto) -> future> { - tracing::add_table_name(state.get_trace_state(), keyspace(), column_family()); - validate_for_read(options.get_consistency()); - - _query_start_time_point = gc_clock::now(); - - update_stats(); - - if (limit > max_ann_query_limit) { - co_await coroutine::return_exception(exceptions::invalid_request_exception( - fmt::format("Use of ANN OF in an ORDER BY clause requires a LIMIT that is not greater than {}. 
LIMIT was {}", max_ann_query_limit, limit))); - } - - auto timeout = db::timeout_clock::now() + get_timeout(state.get_client_state(), options); - auto aoe = abort_on_expiry(timeout); - auto filter_json = _prepared_filter.to_json(options); - uint64_t fetch = static_cast(std::ceil(limit * secondary_index::vector_index::get_oversampling(_index.metadata().options()))); - auto pkeys = co_await qp.vector_store_client().ann( - _schema->ks_name(), _index.metadata().name(), _schema, get_ann_ordering_vector(options), fetch, filter_json, aoe.abort_source()); - if (!pkeys.has_value()) { - co_await coroutine::return_exception( - exceptions::invalid_request_exception(std::visit(vector_search::vector_store_client::ann_error_visitor{}, pkeys.error()))); - } - - if (pkeys->size() > limit && !secondary_index::vector_index::is_rescoring_enabled(_index.metadata().options())) { - pkeys->erase(pkeys->begin() + limit, pkeys->end()); - } - - co_return co_await query_base_table(qp, state, options, pkeys.value(), timeout); - }); - - auto page_size = options.get_page_size(); - if (page_size > 0 && (uint64_t) page_size < limit) { - result->add_warning("Paging is not supported for Vector Search queries. 
The entire result set has been returned."); - } - co_return result; -} - -void vector_indexed_table_select_statement::update_stats() const { - ++_stats.secondary_index_reads; - ++_stats.query_cnt(source_selector::USER, _ks_sel, cond_selector::NO_CONDITIONS, statement_type::SELECT); -} - -lw_shared_ptr vector_indexed_table_select_statement::prepare_command_for_base_query( - query_processor& qp, service::query_state& state, const query_options& options, uint64_t fetch_limit) const { - auto slice = make_partition_slice(options); - return ::make_lw_shared(_schema->id(), _schema->version(), std::move(slice), qp.proxy().get_max_result_size(slice), - query::tombstone_limit(qp.proxy().get_tombstone_limit()), - query::row_limit(get_inner_loop_limit(fetch_limit, _selection->is_aggregate())), query::partition_limit(query::max_partitions), - _query_start_time_point, tracing::make_trace_info(state.get_trace_state()), query_id::create_null_id(), query::is_first_page::no, - options.get_timestamp(state)); -} - -std::vector vector_indexed_table_select_statement::get_ann_ordering_vector(const query_options& options) const { - auto [ann_column, ann_vector_expr] = _prepared_ann_ordering; - auto expr_value = expr::evaluate(ann_vector_expr, options); - if (expr_value.is_null()) { - throw exceptions::invalid_request_exception(fmt::format("Unsupported null value for column {}", _prepared_ann_ordering.first->name_as_text())); - } - auto values = value_cast(ann_column->type->deserialize(std::move(expr_value).to_bytes())); - return util::to_vector(values); -} - -future<::shared_ptr> vector_indexed_table_select_statement::query_base_table(query_processor& qp, - service::query_state& state, const query_options& options, const std::vector& pkeys, - lowres_clock::time_point timeout) const { - auto command = prepare_command_for_base_query(qp, state, options, pkeys.size()); - - auto result = co_await query_base_table(qp, state, options, command, timeout, pkeys); - - 
command->set_row_limit(get_limit(options, _limit)); - - co_return co_await wrap_result_to_error_message([this, command = std::move(command), &options](auto query_result) { - return process_results(std::move(query_result), command, options, _query_start_time_point); - })(std::move(result)); -} - -future>>> vector_indexed_table_select_statement::query_base_table(query_processor& qp, - service::query_state& state, const query_options& options, lw_shared_ptr command, lowres_clock::time_point timeout, - const std::vector& pkeys) const { - - // For tables without clustering columns, we can optimize by querying - // partition ranges instead of individual primary keys, since the - // partition key alone uniquely identifies each row. - if (_schema->clustering_key_size() == 0) { - auto to_partition_ranges = [](const std::vector& pkeys) -> std::vector { - std::vector partition_ranges; - std::ranges::transform(pkeys, std::back_inserter(partition_ranges), [](const auto& pkey) { - return dht::partition_range::make_singular(pkey.partition); - }); - - return partition_ranges; - }; - co_return co_await query_base_table(qp, state, options, std::move(command), timeout, to_partition_ranges(pkeys)); - } - co_return co_await utils::result_map_reduce( - pkeys.begin(), pkeys.end(), - [&](this auto, auto& key) -> future>>> { - auto cmd = ::make_lw_shared(*command); - cmd->slice._row_ranges = query::clustering_row_ranges{query::clustering_range::make_singular(key.clustering)}; - coordinator_result rqr = - co_await qp.proxy().query_result(_schema, cmd, {dht::partition_range::make_singular(key.partition)}, options.get_consistency(), - {timeout, state.get_permit(), state.get_client_state(), state.get_trace_state()}); - if (!rqr) { - co_return std::move(rqr).as_failure(); - } - co_return std::move(rqr.value().query_result); - }, - query::result_merger{command->get_row_limit(), query::max_partitions}); -} - -future>>> vector_indexed_table_select_statement::query_base_table(query_processor& qp, - 
service::query_state& state, const query_options& options, lw_shared_ptr command, lowres_clock::time_point timeout, - std::vector partition_ranges) const { - - coordinator_result rqr = co_await qp.proxy() - .query_result(_query_schema, command, std::move(partition_ranges), options.get_consistency(), - {timeout, state.get_permit(), state.get_client_state(), state.get_trace_state(), {}, {}, options.get_specific_options().node_local_only}, - std::nullopt); - if (!rqr) { - co_return std::move(rqr).as_failure(); - } - co_return std::move(rqr.value().query_result); -} - namespace raw { static void validate_attrs(const cql3::attributes::raw& attrs) { diff --git a/cql3/statements/select_statement.hh b/cql3/statements/select_statement.hh index 262e2510eb..2b61a9c154 100644 --- a/cql3/statements/select_statement.hh +++ b/cql3/statements/select_statement.hh @@ -21,8 +21,6 @@ #include "exceptions/coordinator_result.hh" #include "locator/host_id.hh" #include "service/cas_shard.hh" -#include "vector_search/vector_store_client.hh" -#include "vector_search/filter.hh" namespace service { class client_state; @@ -361,48 +359,5 @@ private: }; -class vector_indexed_table_select_statement : public select_statement { - secondary_index::index _index; - prepared_ann_ordering_type _prepared_ann_ordering; - vector_search::prepared_filter _prepared_filter; - mutable gc_clock::time_point _query_start_time_point; - -public: - static constexpr size_t max_ann_query_limit = 1000; - - static ::shared_ptr prepare(data_dictionary::database db, schema_ptr schema, uint32_t bound_terms, - lw_shared_ptr parameters, ::shared_ptr selection, - ::shared_ptr restrictions, ::shared_ptr> group_by_cell_indices, bool is_reversed, - ordering_comparator_type ordering_comparator, prepared_ann_ordering_type prepared_ann_ordering, std::optional limit, - std::optional per_partition_limit, cql_stats& stats, const secondary_index::index& index, std::unique_ptr attrs); - - vector_indexed_table_select_statement(schema_ptr 
schema, uint32_t bound_terms, lw_shared_ptr parameters, - ::shared_ptr selection, ::shared_ptr restrictions, - ::shared_ptr> group_by_cell_indices, bool is_reversed, ordering_comparator_type ordering_comparator, - prepared_ann_ordering_type prepared_ann_ordering, std::optional limit, std::optional per_partition_limit, - cql_stats& stats, const secondary_index::index& index, vector_search::prepared_filter prepared_filter, std::unique_ptr attrs); - -private: - future<::shared_ptr> do_execute( - query_processor& qp, service::query_state& state, const query_options& options) const override; - - void update_stats() const; - - lw_shared_ptr prepare_command_for_base_query(query_processor& qp, service::query_state& state, const query_options& options, uint64_t fetch_limit) const; - - std::vector get_ann_ordering_vector(const query_options& options) const; - - future<::shared_ptr> query_base_table(query_processor& qp, service::query_state& state, - const query_options& options, const std::vector& pkeys, lowres_clock::time_point timeout) const; - - future>>> query_base_table(query_processor& qp, service::query_state& state, - const query_options& options, lw_shared_ptr command, lowres_clock::time_point timeout, - const std::vector& pkeys) const; - - future>>> query_base_table(query_processor& qp, service::query_state& state, - const query_options& options, lw_shared_ptr command, lowres_clock::time_point timeout, - std::vector partition_ranges) const; -}; - } } From ed1006928f119ad190f49099d49bfc06c01d3c13 Mon Sep 17 00:00:00 2001 From: Szymon Malewski Date: Tue, 26 May 2026 16:08:24 +0200 Subject: [PATCH 3/3] vector_index: move filter into cql3/statements/external_search Move prepared_filter, prepared_restriction, prepared_rhs types and prepare_filter() from vector_search/filter.{hh,cc} into new files cql3/statements/external_search/filter.{hh,cc} under namespace cql3::statements::external_search. 
This eliminates a circular dependency between the cql3 and vector_search modules: the filter code depends heavily on cql3 types (expressions, query_options, statement_restrictions) and belongs in the cql3 layer. This is a follow-up to VECTOR-250, which originally addressed the same circular dependency; that fix has since regressed. --- configure.py | 2 +- cql3/CMakeLists.txt | 1 + .../statements/external_search}/filter.cc | 91 ++++++++++--------- .../statements/external_search}/filter.hh | 10 +- .../vector_indexed_table_select_statement.cc | 4 +- .../vector_indexed_table_select_statement.hh | 6 +- test/vector_search/filter_test.cc | 30 +++--- vector_search/CMakeLists.txt | 1 - 8 files changed, 77 insertions(+), 68 deletions(-) rename {vector_search => cql3/statements/external_search}/filter.cc (68%) rename {vector_search => cql3/statements/external_search}/filter.hh (85%) diff --git a/configure.py b/configure.py index 7c4145b8f9..775e5c12d7 100755 --- a/configure.py +++ b/configure.py @@ -1062,6 +1062,7 @@ scylla_core = (['message/messaging_service.cc', 'cql3/statements/batch_statement.cc', 'cql3/statements/select_statement.cc', 'cql3/statements/external_search/vector_indexed_table_select_statement.cc', + 'cql3/statements/external_search/filter.cc', 'cql3/statements/use_statement.cc', 'cql3/statements/index_prop_defs.cc', 'cql3/statements/index_target.cc', @@ -1388,7 +1389,6 @@ scylla_core = (['message/messaging_service.cc', 'vector_search/dns.cc', 'vector_search/client.cc', 'vector_search/clients.cc', - 'vector_search/filter.cc', 'vector_search/truststore.cc' ] + [Antlr3Grammar('cql3/Cql.g')] \ + scylla_raft_core diff --git a/cql3/CMakeLists.txt b/cql3/CMakeLists.txt index 3d4c3f7c4e..086f6f230f 100644 --- a/cql3/CMakeLists.txt +++ b/cql3/CMakeLists.txt @@ -82,6 +82,7 @@ target_sources(cql3 statements/batch_statement.cc statements/select_statement.cc statements/external_search/vector_indexed_table_select_statement.cc + statements/external_search/filter.cc
statements/use_statement.cc statements/index_prop_defs.cc statements/index_target.cc diff --git a/vector_search/filter.cc b/cql3/statements/external_search/filter.cc similarity index 68% rename from vector_search/filter.cc rename to cql3/statements/external_search/filter.cc index 05dee790b5..8f12172918 100644 --- a/vector_search/filter.cc +++ b/cql3/statements/external_search/filter.cc @@ -6,7 +6,8 @@ * SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.1 */ -#include "vector_search/filter.hh" +#include "cql3/statements/external_search/filter.hh" + #include "cql3/restrictions/statement_restrictions.hh" #include "cql3/query_options.hh" #include "cql3/expr/expr-utils.hh" @@ -14,49 +15,53 @@ #include "types/json_utils.hh" #include "utils/big_decimal.hh" -namespace vector_search { +namespace cql3 { + +namespace statements { + +namespace external_search { namespace { -std::optional to_single_column_op_string(cql3::expr::oper_t op) { +std::optional to_single_column_op_string(expr::oper_t op) { switch (op) { - case cql3::expr::oper_t::EQ: + case expr::oper_t::EQ: return "=="; - case cql3::expr::oper_t::LT: + case expr::oper_t::LT: return "<"; - case cql3::expr::oper_t::LTE: + case expr::oper_t::LTE: return "<="; - case cql3::expr::oper_t::GT: + case expr::oper_t::GT: return ">"; - case cql3::expr::oper_t::GTE: + case expr::oper_t::GTE: return ">="; - case cql3::expr::oper_t::IN: + case expr::oper_t::IN: return "IN"; default: return std::nullopt; } } -std::optional to_multi_column_op_string(cql3::expr::oper_t op) { +std::optional to_multi_column_op_string(expr::oper_t op) { switch (op) { - case cql3::expr::oper_t::EQ: + case expr::oper_t::EQ: return "()==()"; - case cql3::expr::oper_t::LT: + case expr::oper_t::LT: return "()<()"; - case cql3::expr::oper_t::LTE: + case expr::oper_t::LTE: return "()<=()"; - case cql3::expr::oper_t::GT: + case expr::oper_t::GT: return "()>()"; - case cql3::expr::oper_t::GTE: + case expr::oper_t::GTE: return "()>=()"; - case 
cql3::expr::oper_t::IN: + case expr::oper_t::IN: return "()IN()"; default: return std::nullopt; } } -rjson::value value_to_json(const data_type& type, const cql3::raw_value& val) { +rjson::value value_to_json(const data_type& type, const raw_value& val) { if (val.is_null()) { return rjson::null_value(); } @@ -79,33 +84,33 @@ rjson::value value_to_json(const data_type& type, const cql3::raw_value& val) { return rjson::parse(json_str); } -rjson::value lhs_to_json(const cql3::expr::column_value& col) { +rjson::value lhs_to_json(const expr::column_value& col) { return rjson::from_string(col.col->name_as_text()); } -rjson::value lhs_to_json(const cql3::expr::tuple_constructor& lhs_tuple) { +rjson::value lhs_to_json(const expr::tuple_constructor& lhs_tuple) { auto arr = rjson::empty_array(); for (const auto& elem : lhs_tuple.elements) { - if (auto* cv = cql3::expr::as_if(&elem)) { + if (auto* cv = expr::as_if(&elem)) { rjson::push_back(arr, rjson::from_string(cv->col->name_as_text())); } } return arr; } -prepared_restriction make_prepared_restriction(const sstring& op_str, rjson::value lhs_json, const cql3::expr::expression& rhs_expr) { - auto rhs_type = cql3::expr::type_of(rhs_expr); - if (cql3::expr::contains_bind_marker(rhs_expr)) { +prepared_restriction make_prepared_restriction(const sstring& op_str, rjson::value lhs_json, const expr::expression& rhs_expr) { + auto rhs_type = expr::type_of(rhs_expr); + if (expr::contains_bind_marker(rhs_expr)) { return prepared_restriction{ .type_json = rjson::from_string(op_str), .lhs_json = std::move(lhs_json), .rhs = prepared_rhs{std::move(rhs_type), rhs_expr}}; } else { - auto rhs_val = cql3::expr::evaluate(rhs_expr, cql3::query_options({})); + auto rhs_val = expr::evaluate(rhs_expr, query_options({})); return prepared_restriction{.type_json = rjson::from_string(op_str), .lhs_json = std::move(lhs_json), .rhs = value_to_json(rhs_type, rhs_val)}; } } void single_column_restriction_to_prepared( - const cql3::expr::binary_operator& 
binop, const cql3::expr::column_value& col, std::vector& restrictions) { + const expr::binary_operator& binop, const expr::column_value& col, std::vector& restrictions) { auto op_str = to_single_column_op_string(binop.op); if (!op_str) { throw exceptions::unsupported_operation_exception(sstring("Unsupported operator in restriction on column ") + col.col->name_as_text()); @@ -115,7 +120,7 @@ void single_column_restriction_to_prepared( } void multi_column_restriction_to_prepared( - const cql3::expr::binary_operator& binop, const cql3::expr::tuple_constructor& lhs_tuple, std::vector& restrictions) { + const expr::binary_operator& binop, const expr::tuple_constructor& lhs_tuple, std::vector& restrictions) { auto op_str = to_multi_column_op_string(binop.op); if (!op_str) { throw exceptions::unsupported_operation_exception(sstring("Unsupported operator in restriction on columns ") + to_string(lhs_tuple)); @@ -124,25 +129,25 @@ void multi_column_restriction_to_prepared( restrictions.push_back(make_prepared_restriction(*op_str, lhs_to_json(lhs_tuple), binop.rhs)); } -void binary_operator_to_prepared(const cql3::expr::binary_operator& binop, std::vector& restrictions) { - if (auto* cv = cql3::expr::as_if(&binop.lhs)) { +void binary_operator_to_prepared(const expr::binary_operator& binop, std::vector& restrictions) { + if (auto* cv = expr::as_if(&binop.lhs)) { single_column_restriction_to_prepared(binop, *cv, restrictions); return; } - if (auto* tuple = cql3::expr::as_if(&binop.lhs)) { + if (auto* tuple = expr::as_if(&binop.lhs)) { multi_column_restriction_to_prepared(binop, *tuple, restrictions); return; } } -void expression_to_prepared(const cql3::expr::expression& expr, std::vector& restrictions) { - cql3::expr::for_each_expression(expr, [&](const cql3::expr::binary_operator& binop) { +void expression_to_prepared(const expr::expression& expr, std::vector& restrictions) { + expr::for_each_expression(expr, [&](const expr::binary_operator& binop) { 
binary_operator_to_prepared(binop, restrictions); }); } -rjson::value restriction_to_json(const prepared_restriction& r, const cql3::query_options& options) { +rjson::value restriction_to_json(const prepared_restriction& r, const query_options& options) { auto obj = rjson::empty_object(); rjson::add(obj, "type", rjson::copy(r.type_json)); rjson::add(obj, "lhs", rjson::copy(r.lhs_json)); @@ -150,7 +155,7 @@ rjson::value restriction_to_json(const prepared_restriction& r, const cql3::quer return obj; } -rjson::value restrictions_to_json(const std::vector& restrictions, bool allow_filtering, const cql3::query_options& options) { +rjson::value restrictions_to_json(const std::vector& restrictions, bool allow_filtering, const query_options& options) { auto result = rjson::empty_object(); if (restrictions.empty() && !allow_filtering) { @@ -170,7 +175,7 @@ rjson::value restrictions_to_json(const std::vector& restr } // anonymous namespace -rjson::value prepared_restriction::rhs_to_json(const cql3::query_options& options) const { +rjson::value prepared_restriction::rhs_to_json(const query_options& options) const { return std::visit( [&](const auto& v) -> rjson::value { using T = std::decay_t; @@ -178,14 +183,14 @@ rjson::value prepared_restriction::rhs_to_json(const cql3::query_options& option return rjson::copy(v); } else { const auto& [type, expr] = v; - auto val = cql3::expr::evaluate(expr, options); + auto val = expr::evaluate(expr, options); return value_to_json(type, val); } }, rhs); } -rjson::value prepared_filter::to_json(const cql3::query_options& options) const { +rjson::value prepared_filter::to_json(const query_options& options) const { if (_cached_json) { return rjson::copy(_cached_json.value()); } @@ -193,7 +198,7 @@ rjson::value prepared_filter::to_json(const cql3::query_options& options) const return restrictions_to_json(_restrictions, _allow_filtering, options); } -prepared_filter prepare_filter(const cql3::restrictions::statement_restrictions& restrictions, 
bool allow_filtering) { +prepared_filter prepare_filter(const restrictions::statement_restrictions& restrictions, bool allow_filtering) { if (restrictions.is_empty()) { return prepared_filter({}, allow_filtering); } @@ -208,9 +213,9 @@ prepared_filter prepare_filter(const cql3::restrictions::statement_restrictions& expression_to_prepared(clustering_columns_restrictions, prepared_restrictions); expression_to_prepared(nonprimary_key_restrictions, prepared_restrictions); - bool has_bind_markers = cql3::expr::contains_bind_marker(partition_key_restrictions) - || cql3::expr::contains_bind_marker(clustering_columns_restrictions) - || cql3::expr::contains_bind_marker(nonprimary_key_restrictions); + bool has_bind_markers = expr::contains_bind_marker(partition_key_restrictions) + || expr::contains_bind_marker(clustering_columns_restrictions) + || expr::contains_bind_marker(nonprimary_key_restrictions); if (!has_bind_markers) { auto cached_json = restrictions_to_json(prepared_restrictions, allow_filtering, cql3::query_options({})); @@ -220,4 +225,8 @@ prepared_filter prepare_filter(const cql3::restrictions::statement_restrictions& return prepared_filter(std::move(prepared_restrictions), allow_filtering); } -} // namespace vector_search +} // namespace external_search + +} // namespace statements + +} // namespace cql3 diff --git a/vector_search/filter.hh b/cql3/statements/external_search/filter.hh similarity index 85% rename from vector_search/filter.hh rename to cql3/statements/external_search/filter.hh index 7a7ff4782b..608d4e0899 100644 --- a/vector_search/filter.hh +++ b/cql3/statements/external_search/filter.hh @@ -22,11 +22,11 @@ class statement_restrictions; } } // namespace cql3 -namespace vector_search { +namespace cql3::statements::external_search { struct prepared_rhs { data_type type; - cql3::expr::expression expr; + expr::expression expr; }; struct prepared_restriction { @@ -34,7 +34,7 @@ struct prepared_restriction { rjson::value lhs_json; std::variant rhs; - 
rjson::value rhs_to_json(const cql3::query_options& options) const; + rjson::value rhs_to_json(const query_options& options) const; }; class prepared_filter { @@ -51,7 +51,7 @@ public: } /// Serializes the prepared filter to JSON compatible with the Vector Store service filtering API. - rjson::value to_json(const cql3::query_options& options) const; + rjson::value to_json(const query_options& options) const; }; /// Prepares a filter from CQL statement restrictions for use in Vector Store service. @@ -59,4 +59,4 @@ public: /// and prepares them for serialization to JSON compatible to Vector Store service filtering API. prepared_filter prepare_filter(const cql3::restrictions::statement_restrictions& restrictions, bool allow_filtering); -} // namespace vector_search +} // namespace cql3::statements::external_search diff --git a/cql3/statements/external_search/vector_indexed_table_select_statement.cc b/cql3/statements/external_search/vector_indexed_table_select_statement.cc index 469d653897..8f9383b655 100644 --- a/cql3/statements/external_search/vector_indexed_table_select_statement.cc +++ b/cql3/statements/external_search/vector_indexed_table_select_statement.cc @@ -196,7 +196,7 @@ select_statement::ordering_comparator_type get_similarity_ordering_comparator(st ordering_comparator_type ordering_comparator, prepared_ann_ordering_type prepared_ann_ordering, std::optional limit, std::optional per_partition_limit, cql_stats& stats, const secondary_index::index& index, std::unique_ptr attrs) { - auto prepared_filter = vector_search::prepare_filter(*restrictions, parameters->allow_filtering()); + auto prepared_filter = external_search::prepare_filter(*restrictions, parameters->allow_filtering()); return ::make_shared(schema, bound_terms, parameters, std::move(selection), std::move(restrictions), std::move(group_by_cell_indices), is_reversed, std::move(ordering_comparator), std::move(prepared_ann_ordering), std::move(limit), @@ -208,7 +208,7 @@ 
vector_indexed_table_select_statement::vector_indexed_table_select_statement(sch ::shared_ptr> group_by_cell_indices, bool is_reversed, ordering_comparator_type ordering_comparator, prepared_ann_ordering_type prepared_ann_ordering, std::optional limit, std::optional per_partition_limit, cql_stats& stats, const secondary_index::index& index, - vector_search::prepared_filter prepared_filter, std::unique_ptr attrs) + external_search::prepared_filter prepared_filter, std::unique_ptr attrs) : select_statement{schema, bound_terms, parameters, selection, restrictions, group_by_cell_indices, is_reversed, ordering_comparator, limit, per_partition_limit, stats, std::move(attrs)} , _index{index} diff --git a/cql3/statements/external_search/vector_indexed_table_select_statement.hh b/cql3/statements/external_search/vector_indexed_table_select_statement.hh index cd6239f7d0..e27997b67c 100644 --- a/cql3/statements/external_search/vector_indexed_table_select_statement.hh +++ b/cql3/statements/external_search/vector_indexed_table_select_statement.hh @@ -9,8 +9,8 @@ #pragma once #include "cql3/statements/select_statement.hh" +#include "cql3/statements/external_search/filter.hh" #include "vector_search/vector_store_client.hh" -#include "vector_search/filter.hh" #include @@ -47,7 +47,7 @@ select_statement::ordering_comparator_type get_similarity_ordering_comparator( class vector_indexed_table_select_statement : public select_statement { secondary_index::index _index; prepared_ann_ordering_type _prepared_ann_ordering; - vector_search::prepared_filter _prepared_filter; + external_search::prepared_filter _prepared_filter; mutable gc_clock::time_point _query_start_time_point; public: @@ -63,7 +63,7 @@ public: ::shared_ptr selection, ::shared_ptr restrictions, ::shared_ptr> group_by_cell_indices, bool is_reversed, ordering_comparator_type ordering_comparator, prepared_ann_ordering_type prepared_ann_ordering, std::optional limit, std::optional per_partition_limit, - cql_stats& stats, const 
secondary_index::index& index, vector_search::prepared_filter prepared_filter, std::unique_ptr attrs); + cql_stats& stats, const secondary_index::index& index, external_search::prepared_filter prepared_filter, std::unique_ptr attrs); private: future<::shared_ptr> do_execute( diff --git a/test/vector_search/filter_test.cc b/test/vector_search/filter_test.cc index 5cead199cd..6bc50d6f4d 100644 --- a/test/vector_search/filter_test.cc +++ b/test/vector_search/filter_test.cc @@ -16,7 +16,7 @@ #include "types/types.hh" #include "utils/big_decimal.hh" #include "utils/rjson.hh" -#include "vector_search/filter.hh" +#include "cql3/statements/external_search/filter.hh" BOOST_AUTO_TEST_SUITE(filter_test) @@ -55,7 +55,7 @@ query_options make_query_options(std::vector values) { /// Helper to get JSON string from restrictions sstring get_restrictions_json(const restrictions::statement_restrictions& restr, bool allow_filtering = false) { - return rjson::print(vector_search::prepare_filter(restr, allow_filtering).to_json(query_options({}))); + return rjson::print(statements::external_search::prepare_filter(restr, allow_filtering).to_json(query_options({}))); } } // anonymous namespace @@ -66,7 +66,7 @@ SEASTAR_TEST_CASE(to_json_empty_restrictions) { auto schema = e.local_db().find_schema("ks", "t"); shared_ptr restr = restrictions::make_trivial_statement_restrictions(schema, false); - auto json = rjson::print(vector_search::prepare_filter(*restr, false).to_json(query_options({}))); + auto json = rjson::print(statements::external_search::prepare_filter(*restr, false).to_json(query_options({}))); BOOST_CHECK_EQUAL(json, "{}"); }); @@ -269,7 +269,7 @@ SEASTAR_TEST_CASE(to_json_bind_marker_partition_key) { cquery_nofail(e, "create table ks.t(pk int, ck int, v vector, primary key(pk, ck))"); auto restr = make_restrictions("pk=?", e); - auto filter = vector_search::prepare_filter(*restr, false); + auto filter = statements::external_search::prepare_filter(*restr, false); std::vector 
bind_values = {raw_value::make_value(int32_type->decompose(42))}; auto options = make_query_options(std::move(bind_values)); @@ -285,7 +285,7 @@ SEASTAR_TEST_CASE(to_json_bind_marker_clustering_key) { cquery_nofail(e, "create table ks.t(pk int, ck int, v vector, primary key(pk, ck))"); auto restr = make_restrictions("pk=? and ck>?", e); - auto filter = vector_search::prepare_filter(*restr, true); + auto filter = statements::external_search::prepare_filter(*restr, true); std::vector bind_values = { raw_value::make_value(int32_type->decompose(1)), @@ -303,7 +303,7 @@ SEASTAR_TEST_CASE(to_json_bind_marker_different_values) { cquery_nofail(e, "create table ks.t(pk int, ck int, v vector, primary key(pk, ck))"); auto restr = make_restrictions("pk=?", e); - auto filter = vector_search::prepare_filter(*restr, false); + auto filter = statements::external_search::prepare_filter(*restr, false); std::vector bind_values1 = {raw_value::make_value(int32_type->decompose(100))}; auto options1 = make_query_options(std::move(bind_values1)); @@ -324,7 +324,7 @@ SEASTAR_TEST_CASE(to_json_bind_marker_string_value) { cquery_nofail(e, "create table ks.t(pk text, ck int, v vector, primary key(pk, ck))"); auto restr = make_restrictions("pk=?", e); - auto filter = vector_search::prepare_filter(*restr, false); + auto filter = statements::external_search::prepare_filter(*restr, false); std::vector bind_values = {raw_value::make_value(utf8_type->decompose("hello_world"))}; auto options = make_query_options(std::move(bind_values)); @@ -340,7 +340,7 @@ SEASTAR_TEST_CASE(to_json_mixed_literals_and_bind_markers) { cquery_nofail(e, "create table ks.t(pk int, ck int, v vector, primary key(pk, ck))"); auto restr = make_restrictions("pk=1 and ck>?", e); - auto filter = vector_search::prepare_filter(*restr, true); + auto filter = statements::external_search::prepare_filter(*restr, true); std::vector bind_values = {raw_value::make_value(int32_type->decompose(25))}; auto options = 
make_query_options(std::move(bind_values)); @@ -356,7 +356,7 @@ SEASTAR_TEST_CASE(to_json_bind_marker_in_list) { cquery_nofail(e, "create table ks.t(pk int, ck int, v vector, primary key(pk, ck))"); auto restr = make_restrictions("pk=1 and ck in ?", e); - auto filter = vector_search::prepare_filter(*restr, true); + auto filter = statements::external_search::prepare_filter(*restr, true); auto list_type = list_type_impl::get_instance(int32_type, true); auto list_val = make_list_value(list_type, {data_value(10), data_value(20), data_value(30)}); @@ -375,7 +375,7 @@ SEASTAR_TEST_CASE(to_json_bind_marker_multi_column) { cquery_nofail(e, "create table ks.t(pk int, ck1 int, ck2 int, v vector, primary key(pk, ck1, ck2))"); auto restr = make_restrictions("pk=1 and (ck1, ck2)>?", e); - auto filter = vector_search::prepare_filter(*restr, true); + auto filter = statements::external_search::prepare_filter(*restr, true); auto tuple_type = tuple_type_impl::get_instance({int32_type, int32_type}); auto tuple_val = make_tuple_value(tuple_type, {data_value(10), data_value(20)}); @@ -394,7 +394,7 @@ SEASTAR_TEST_CASE(to_json_no_bind_markers_uses_cache) { cquery_nofail(e, "create table ks.t(pk int, ck int, v vector, primary key(pk, ck))"); auto restr = make_restrictions("pk=42", e); - auto filter = vector_search::prepare_filter(*restr, false); + auto filter = statements::external_search::prepare_filter(*restr, false); auto options1 = query_options({}); auto json1 = rjson::print(filter.to_json(options1)); @@ -438,7 +438,7 @@ SEASTAR_TEST_CASE(to_json_nonprimary_key_bind_marker) { cquery_nofail(e, "create table ks.t(pk int, ck int, r int, v vector, primary key(pk, ck))"); auto restr = make_restrictions("pk=1 and r=?", e); - auto filter = vector_search::prepare_filter(*restr, true); + auto filter = statements::external_search::prepare_filter(*restr, true); std::vector bind_values = {raw_value::make_value(int32_type->decompose(99))}; auto options = 
make_query_options(std::move(bind_values)); @@ -469,7 +469,7 @@ SEASTAR_TEST_CASE(to_json_decimal) { // Same value via bind marker. restr = make_restrictions("pk=?", e); - auto filter = vector_search::prepare_filter(*restr, false); + auto filter = statements::external_search::prepare_filter(*restr, false); std::vector bind_values = { raw_value::make_value(decimal_type->decompose(big_decimal("98765432109876543210.12345")))}; auto options = make_query_options(std::move(bind_values)); @@ -486,7 +486,7 @@ SEASTAR_TEST_CASE(to_json_decimal) { // "1.230", not be normalized to "1.23". These are different partition // keys because the wire format differs. restr = make_restrictions("pk=?", e); - auto filter2 = vector_search::prepare_filter(*restr, false); + auto filter2 = statements::external_search::prepare_filter(*restr, false); std::vector bind_values2 = { raw_value::make_value(decimal_type->decompose(big_decimal("1.230")))}; auto options2 = make_query_options(std::move(bind_values2)); @@ -508,7 +508,7 @@ SEASTAR_TEST_CASE(to_json_varint) { // Same value via bind marker. restr = make_restrictions("pk=?", e); - auto filter = vector_search::prepare_filter(*restr, false); + auto filter = statements::external_search::prepare_filter(*restr, false); std::vector bind_values = { raw_value::make_value(varint_type->decompose(utils::multiprecision_int("98765432109876543210")))}; auto options = make_query_options(std::move(bind_values)); diff --git a/vector_search/CMakeLists.txt b/vector_search/CMakeLists.txt index 2a1b18426b..32cd983278 100644 --- a/vector_search/CMakeLists.txt +++ b/vector_search/CMakeLists.txt @@ -5,7 +5,6 @@ target_sources(vector_search dns.cc client.cc clients.cc - filter.cc truststore.cc) target_link_libraries(vector_search PUBLIC