From ead996178317cba2a079cae051117a257b70cdc6 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Mon, 19 Jan 2026 09:44:02 +0200 Subject: [PATCH] cql: vector: fix vector dimension type Switch vector dimension handling to fixed-width `uint32_t` type, update parsing/validation, and add boundary tests. The dimension is parsed as `unsigned long` at first which is guaranteed to be **at least** 32-bit long, which is safe to downcast to `uint32_t`. Move `MAX_VECTOR_DIMENSION` from `cql3_type::raw_vector` to `cql3_type` to ensure public visibility for checks outside the class. Add tests to verify the type boundaries. Fixes: https://scylladb.atlassian.net/browse/SCYLLADB-223 Signed-off-by: Yaniv Kaul Co-authored-by: Dawid Pawlik Closes scylladb/scylladb#28762 --- cql3/Cql.g | 16 ++++++++- cql3/assignment_testable.hh | 2 +- cql3/cql3_type.cc | 9 ++--- cql3/cql3_type.hh | 5 ++- cql3/expr/prepare_expr.cc | 4 +-- cql3/functions/vector_similarity_fcts.cc | 6 ++-- cql3/functions/vector_similarity_fcts.hh | 2 +- db/marshal/type_parser.cc | 20 +++++++++-- db/marshal/type_parser.hh | 2 +- lang/lua.cc | 2 +- test/cqlpy/test_type_vector.py | 43 +++++++++++++++++++++++- test/lib/expr_test_utils.cc | 2 +- test/lib/expr_test_utils.hh | 2 +- types/types.cc | 10 +++--- types/types.hh | 2 ++ types/vector.hh | 14 ++++---- 16 files changed, 106 insertions(+), 35 deletions(-) diff --git a/cql3/Cql.g b/cql3/Cql.g index d7020c0a32..c54dd52d06 100644 --- a/cql3/Cql.g +++ b/cql3/Cql.g @@ -2071,7 +2071,21 @@ vector_type returns [shared_ptr pt] { if ($d.text[0] == '-') throw exceptions::invalid_request_exception("Vectors must have a dimension greater than 0"); - $pt = cql3::cql3_type::raw::vector(t, std::stoul($d.text)); + unsigned long parsed_dimension; + try { + parsed_dimension = std::stoul($d.text); + } catch (const std::exception& e) { + throw exceptions::invalid_request_exception(format("Invalid vector dimension: {}", $d.text)); + } + static_assert(sizeof(unsigned long) >= sizeof(vector_dimension_t)); + if (parsed_dimension == 0) { + throw exceptions::invalid_request_exception("Vectors must have a dimension greater than 0"); + } + if (parsed_dimension > cql3::cql3_type::MAX_VECTOR_DIMENSION) { + throw exceptions::invalid_request_exception( + format("Vectors must have a dimension less than or equal to {}", cql3::cql3_type::MAX_VECTOR_DIMENSION)); + } + $pt = cql3::cql3_type::raw::vector(t, static_cast(parsed_dimension)); } ; diff --git a/cql3/assignment_testable.hh b/cql3/assignment_testable.hh index 4db5dd476b..04b820a710 100644 --- a/cql3/assignment_testable.hh +++ b/cql3/assignment_testable.hh @@ -27,7 +27,7 @@ public: struct vector_test_result { test_result result; - std::optional dimension_opt; + std::optional dimension_opt; }; static bool is_assignable(test_result tr) { diff --git a/cql3/cql3_type.cc b/cql3/cql3_type.cc index 0af004df30..33ad293624 100644 --- a/cql3/cql3_type.cc +++ b/cql3/cql3_type.cc @@ -307,17 +307,14 @@ public: class cql3_type::raw_vector : public raw { shared_ptr _type; - size_t _dimension; - - // This limitation is acquired from the maximum number of dimensions in OpenSearch. - static constexpr size_t MAX_VECTOR_DIMENSION = 16000; + vector_dimension_t _dimension; virtual sstring to_string() const override { return seastar::format("vector<{}, {}>", _type, _dimension); } public: - raw_vector(shared_ptr type, size_t dimension) + raw_vector(shared_ptr type, vector_dimension_t dimension) : _type(std::move(type)), _dimension(dimension) { } @@ -417,7 +414,7 @@ cql3_type::raw::tuple(std::vector> ts) { } shared_ptr -cql3_type::raw::vector(shared_ptr t, size_t dimension) { +cql3_type::raw::vector(shared_ptr t, vector_dimension_t dimension) { return ::make_shared(std::move(t), dimension); } diff --git a/cql3/cql3_type.hh b/cql3/cql3_type.hh index 614948bdd5..67ea3349d8 100644 --- a/cql3/cql3_type.hh +++ b/cql3/cql3_type.hh @@ -39,6 +39,9 @@ public: data_type get_type() const { return _type; } const sstring& to_string() const { return _type->cql3_type_name(); } + // This limitation is acquired from the maximum number of dimensions in OpenSearch. + static constexpr vector_dimension_t MAX_VECTOR_DIMENSION = 16000; + // For UserTypes, we need to know the current keyspace to resolve the // actual type used, so Raw is a "not yet prepared" CQL3Type. class raw { @@ -64,7 +67,7 @@ public: static shared_ptr list(shared_ptr t); static shared_ptr set(shared_ptr t); static shared_ptr tuple(std::vector> ts); - static shared_ptr vector(shared_ptr t, size_t dimension); + static shared_ptr vector(shared_ptr t, vector_dimension_t dimension); static shared_ptr frozen(shared_ptr t); friend sstring format_as(const raw& r) { return r.to_string(); diff --git a/cql3/expr/prepare_expr.cc b/cql3/expr/prepare_expr.cc index af6dda6250..1d35530c91 100644 --- a/cql3/expr/prepare_expr.cc +++ b/cql3/expr/prepare_expr.cc @@ -502,8 +502,8 @@ vector_validate_assignable_to(const collection_constructor& c, data_dictionary:: throw exceptions::invalid_request_exception(format("Invalid vector type literal for {} of type {}", *receiver.name, receiver.type->as_cql3_type())); } - size_t expected_size = vt->get_dimension(); - if (!expected_size) { + vector_dimension_t expected_size = vt->get_dimension(); + if (expected_size == 0) { throw exceptions::invalid_request_exception(format("Invalid vector type literal for {}: type {} expects at least one element", *receiver.name, receiver.type->as_cql3_type())); } diff --git a/cql3/functions/vector_similarity_fcts.cc b/cql3/functions/vector_similarity_fcts.cc index 52adabc6a2..87e220e6d2 100644 --- a/cql3/functions/vector_similarity_fcts.cc +++ b/cql3/functions/vector_similarity_fcts.cc @@ -18,7 +18,7 @@ namespace functions { namespace detail { -std::vector extract_float_vector(const bytes_opt& param, size_t dimension) { +std::vector extract_float_vector(const bytes_opt& param, vector_dimension_t dimension) { if (!param) { throw exceptions::invalid_request_exception("Cannot extract float vector from null parameter"); } @@ -156,7 +156,7 @@ std::vector retrieve_vector_arg_types(const function_name& name, cons } } - size_t dimension = first_dim_opt ? *first_dim_opt : *second_dim_opt; + vector_dimension_t dimension = first_dim_opt ? *first_dim_opt : *second_dim_opt; auto type = vector_type_impl::get_instance(float_type, dimension); return {type, type}; } @@ -170,7 +170,7 @@ bytes_opt vector_similarity_fct::execute(std::span parameters) // Extract dimension from the vector type const auto& type = static_cast(*arg_types()[0]); - size_t dimension = type.get_dimension(); + vector_dimension_t dimension = type.get_dimension(); // Optimized path: extract floats directly from bytes, bypassing data_value overhead std::vector v1 = detail::extract_float_vector(parameters[0], dimension); diff --git a/cql3/functions/vector_similarity_fcts.hh b/cql3/functions/vector_similarity_fcts.hh index 662f7df1d1..708d88a21b 100644 --- a/cql3/functions/vector_similarity_fcts.hh +++ b/cql3/functions/vector_similarity_fcts.hh @@ -39,7 +39,7 @@ namespace detail { // Extract float vector directly from serialized bytes, bypassing data_value overhead. // This is an internal API exposed for testing purposes. // Vector wire format: N floats as big-endian uint32_t values, 4 bytes each. -std::vector extract_float_vector(const bytes_opt& param, size_t dimension); +std::vector extract_float_vector(const bytes_opt& param, vector_dimension_t dimension); } // namespace detail diff --git a/db/marshal/type_parser.cc b/db/marshal/type_parser.cc index f2f635534a..acbfe54d8d 100644 --- a/db/marshal/type_parser.cc +++ b/db/marshal/type_parser.cc @@ -16,6 +16,7 @@ #include #include +#include "cql3/cql3_type.hh" #include "types/user.hh" #include "types/map.hh" #include "types/list.hh" @@ -113,7 +114,7 @@ std::vector type_parser::get_type_parameters(bool multicell) throw parse_exception(_str, _idx, "unexpected end of string"); } -std::tuple type_parser::get_vector_parameters() +std::tuple type_parser::get_vector_parameters() { if (is_eos() || _str[_idx] != '(') { throw std::logic_error("internal error"); @@ -128,7 +129,7 @@ std::tuple type_parser::get_vector_parameters() } data_type type = do_parse(true); - size_t size = 0; + vector_dimension_t size = 0; if (_str[_idx] == ',') { ++_idx; skip_blank(); @@ -142,7 +143,20 @@ std::tuple type_parser::get_vector_parameters() throw parse_exception(_str, _idx, "expected digit or ')'"); } - size = std::stoul(_str.substr(i, _idx - i)); + unsigned long parsed_size; + try { + parsed_size = std::stoul(_str.substr(i, _idx - i)); + } catch (const std::exception& e) { + throw parse_exception(_str, i, format("Invalid vector dimension: {}", e.what())); + } + static_assert(sizeof(unsigned long) >= sizeof(vector_dimension_t)); + if (parsed_size == 0) { + throw parse_exception(_str, _idx, "Vectors must have a dimension greater than 0"); + } + if (parsed_size > cql3::cql3_type::MAX_VECTOR_DIMENSION) { + throw parse_exception(_str, _idx, format("Vectors must have a dimension less than or equal to {}", cql3::cql3_type::MAX_VECTOR_DIMENSION)); + } + size = static_cast(parsed_size); ++_idx; // skipping ')' return std::make_tuple(type, size); diff --git a/db/marshal/type_parser.hh b/db/marshal/type_parser.hh index d1e6b2b69c..be82dd47af 100644 --- a/db/marshal/type_parser.hh +++ b/db/marshal/type_parser.hh @@ -97,7 +97,7 @@ public: } #endif std::vector get_type_parameters(bool multicell=true); - std::tuple get_vector_parameters(); + std::tuple get_vector_parameters(); std::tuple, std::vector> get_user_type_parameters(); data_type do_parse(bool multicell = true); diff --git a/lang/lua.cc b/lang/lua.cc index 991050f811..e347652169 100644 --- a/lang/lua.cc +++ b/lang/lua.cc @@ -743,7 +743,7 @@ struct from_lua_visitor { } const data_type& elements_type = t.get_elements_type(); - size_t num_elements = t.get_dimension(); + vector_dimension_t num_elements = t.get_dimension(); using table_pair = std::pair; std::vector pairs; diff --git a/test/cqlpy/test_type_vector.py b/test/cqlpy/test_type_vector.py index 5631a3f09f..edbdae9ea9 100644 --- a/test/cqlpy/test_type_vector.py +++ b/test/cqlpy/test_type_vector.py @@ -4,7 +4,10 @@ ############################################################################# # Tests involving the "vector" column type. -from .util import new_test_table + +import pytest +from .util import new_test_table, is_scylla +from cassandra.protocol import InvalidRequest, SyntaxException def test_vector_of_set_using_arguments_binding(cql, test_keyspace): @@ -26,3 +29,41 @@ def test_vector_of_set_using_arguments_binding(cql, test_keyspace): assert row is not None assert row.v == value_to_insert + + +# This is an artificial limit set to the value matching the OpenSearch implementation of Vector Search. +# Cassandra itself does not have a hard limit on the dimension of vectors, except mentioning 2^13 as a recommended maximum in the documentation. +# Instead, we test with Java's Integer.MAX_VALUE (2^31 - 1). +@pytest.fixture(scope="module") +def MAX_VECTOR_DIMENSION(cql): + return 16000 if is_scylla(cql) else 2**31 - 1 + + +def test_vector_dimension_upper_bound_is_allowed(cql, test_keyspace, MAX_VECTOR_DIMENSION): + with new_test_table(cql, test_keyspace, f"pk int primary key, v vector"): + pass + + +def test_vector_dimension_above_upper_bound_is_rejected(cql, test_keyspace, MAX_VECTOR_DIMENSION): + with pytest.raises(InvalidRequest, match=f"Vectors must have a dimension less than or equal to {MAX_VECTOR_DIMENSION}") if is_scylla(cql) else pytest.raises(SyntaxException, match="NumberFormatException"): + with new_test_table(cql, test_keyspace, f"pk int primary key, v vector"): + pass + + +def test_vector_dimension_zero_is_rejected(cql, test_keyspace): + with pytest.raises(InvalidRequest, match="Vectors must have a dimension greater than 0" if is_scylla(cql) else "vectors may only have positive dimensions"): + with new_test_table(cql, test_keyspace, "pk int primary key, v vector"): + pass + + +def test_vector_dimension_negative_is_rejected(cql, test_keyspace): + with pytest.raises(InvalidRequest, match="Vectors must have a dimension greater than 0" if is_scylla(cql) else "vectors may only have positive dimensions"): + with new_test_table(cql, test_keyspace, "pk int primary key, v vector"): + pass + + +@pytest.mark.parametrize("invalid_dimension", ["dog", "123x", "1.5"]) +def test_vector_dimension_non_integer_is_rejected(cql, test_keyspace, invalid_dimension): + with pytest.raises(SyntaxException): + with new_test_table(cql, test_keyspace, f"pk int primary key, v vector"): + pass diff --git a/test/lib/expr_test_utils.cc b/test/lib/expr_test_utils.cc index 4773136ea3..ddeae3dbb4 100644 --- a/test/lib/expr_test_utils.cc +++ b/test/lib/expr_test_utils.cc @@ -392,7 +392,7 @@ tuple_constructor make_tuple_constructor(std::vector elements, std:: .type = tuple_type_impl::get_instance(std::move(element_types))}; } -collection_constructor make_vector_constructor(std::vector elements, data_type elements_type, size_t dimension) { +collection_constructor make_vector_constructor(std::vector elements, data_type elements_type, vector_dimension_t dimension) { return collection_constructor{.style = collection_constructor::style_type::vector, .elements = std::move(elements), .type = vector_type_impl::get_instance(elements_type, dimension)}; diff --git a/test/lib/expr_test_utils.hh b/test/lib/expr_test_utils.hh index b65136663a..4f48748e98 100644 --- a/test/lib/expr_test_utils.hh +++ b/test/lib/expr_test_utils.hh @@ -116,7 +116,7 @@ collection_constructor make_map_constructor(const std::vector elements, std::vector element_types); -collection_constructor make_vector_constructor(std::vector elements, data_type elements_type, size_t dimension); +collection_constructor make_vector_constructor(std::vector elements, data_type elements_type, vector_dimension_t dimension); usertype_constructor make_usertype_constructor(std::vector> field_values); ::lw_shared_ptr make_receiver(data_type receiver_type, sstring name = "receiver_name"); diff --git a/types/types.cc b/types/types.cc index 174286f599..5a233b365d 100644 --- a/types/types.cc +++ b/types/types.cc @@ -1643,22 +1643,22 @@ static void validate_aux(const tuple_type_impl& t, View v) { } } -sstring vector_type_impl::make_name(data_type type, size_t dimension) { +sstring vector_type_impl::make_name(data_type type, vector_dimension_t dimension) { // To keep format compatibility with Origin we never wrap // vector name into // "org.apache.cassandra.db.marshal.FrozenType(...)". return seastar::format("org.apache.cassandra.db.marshal.VectorType({}, {})", type->name(), dimension); } -vector_type_impl::vector_type_impl(data_type elements, size_t dimension) +vector_type_impl::vector_type_impl(data_type elements, vector_dimension_t dimension) : concrete_type(kind::vector, make_name(elements, dimension), - elements->value_length_if_fixed() ? std::optional(elements->value_length_if_fixed().value()*dimension):std::nullopt), + elements->value_length_if_fixed() ? std::optional(elements->value_length_if_fixed().value() * dimension) : std::nullopt), _elements_type(elements), _dimension(dimension) { _contains_set_or_map = _elements_type->contains_set_or_map(); } shared_ptr -vector_type_impl::get_instance(data_type elements, size_t dimension) { +vector_type_impl::get_instance(data_type elements, vector_dimension_t dimension) { return intern::get_instance(elements, dimension); } @@ -1681,7 +1681,7 @@ static void serialize_vector(const vector_type_impl& type, const vector_type_imp } std::strong_ordering -vector_type_impl::compare_vectors(data_type elements, size_t dimension, managed_bytes_view o1, managed_bytes_view o2) { +vector_type_impl::compare_vectors(data_type elements, vector_dimension_t dimension, managed_bytes_view o1, managed_bytes_view o2) { if (o1.empty()) { return o2.empty() ? std::strong_ordering::equal : std::strong_ordering::less; } else if (o2.empty()) { diff --git a/types/types.hh b/types/types.hh index 80a7129a27..f13a7767de 100644 --- a/types/types.hh +++ b/types/types.hh @@ -176,6 +176,8 @@ struct timeuuid_native_type { using data_type = shared_ptr; +using vector_dimension_t = uint32_t; + template const T& value_cast(const data_value& value); diff --git a/types/vector.hh b/types/vector.hh index ed102f3347..d54ff400bf 100644 --- a/types/vector.hh +++ b/types/vector.hh @@ -14,20 +14,20 @@ #include "vint-serialization.hh" class vector_type_impl : public concrete_type> { - using intern = type_interning_helper; + using intern = type_interning_helper; protected: data_type _elements_type; - size_t _dimension; + vector_dimension_t _dimension; public: - vector_type_impl(data_type elements_type, size_t dimension); - static shared_ptr get_instance(data_type type, size_t dimension); + vector_type_impl(data_type elements_type, vector_dimension_t dimension); + static shared_ptr get_instance(data_type type, vector_dimension_t dimension); data_type get_elements_type() const { return _elements_type; } - size_t get_dimension() const { + vector_dimension_t get_dimension() const { return _dimension; } - static std::strong_ordering compare_vectors(data_type elements_comparator, size_t dimension, + static std::strong_ordering compare_vectors(data_type elements_comparator, vector_dimension_t dimension, managed_bytes_view o1, managed_bytes_view o2); std::vector split_fragmented(FragmentedView auto v) const { @@ -94,7 +94,7 @@ public: return ret; } private: - static sstring make_name(data_type type, size_t dimension); + static sstring make_name(data_type type, vector_dimension_t dimension); };