cql: vector: fix vector dimension type

Switch vector dimension handling to fixed-width `uint32_t` type,
update parsing/validation, and add boundary tests.

The dimension is parsed as `unsigned long` at first which is guaranteed
to be **at least** 32-bit long, which is safe to downcast to `uint32_t`.

Move `MAX_VECTOR_DIMENSION` from `cql3_type::raw_vector` to `cql3_type`
to ensure public visibility for checks outside the class.

Add tests to verify the type boundaries.

Fixes: https://scylladb.atlassian.net/browse/SCYLLADB-223

Signed-off-by: Yaniv Kaul <yaniv.kaul@scylladb.com>
Co-authored-by: Dawid Pawlik <dawid.pawlik@scylladb.com>

Closes scylladb/scylladb#28762
This commit is contained in:
Yaniv Michael Kaul
2026-01-19 09:44:02 +02:00
committed by Nadav Har'El
parent 4a60ee28a2
commit ead9961783
16 changed files with 106 additions and 35 deletions

View File

@@ -2071,7 +2071,21 @@ vector_type returns [shared_ptr<cql3::cql3_type::raw> pt]
{
if ($d.text[0] == '-')
throw exceptions::invalid_request_exception("Vectors must have a dimension greater than 0");
$pt = cql3::cql3_type::raw::vector(t, std::stoul($d.text));
unsigned long parsed_dimension;
try {
parsed_dimension = std::stoul($d.text);
} catch (const std::exception& e) {
throw exceptions::invalid_request_exception(format("Invalid vector dimension: {}", $d.text));
}
static_assert(sizeof(unsigned long) >= sizeof(vector_dimension_t));
if (parsed_dimension == 0) {
throw exceptions::invalid_request_exception("Vectors must have a dimension greater than 0");
}
if (parsed_dimension > cql3::cql3_type::MAX_VECTOR_DIMENSION) {
throw exceptions::invalid_request_exception(
format("Vectors must have a dimension less than or equal to {}", cql3::cql3_type::MAX_VECTOR_DIMENSION));
}
$pt = cql3::cql3_type::raw::vector(t, static_cast<vector_dimension_t>(parsed_dimension));
}
;

View File

@@ -27,7 +27,7 @@ public:
struct vector_test_result {
test_result result;
std::optional<size_t> dimension_opt;
std::optional<vector_dimension_t> dimension_opt;
};
static bool is_assignable(test_result tr) {

View File

@@ -307,17 +307,14 @@ public:
class cql3_type::raw_vector : public raw {
shared_ptr<raw> _type;
size_t _dimension;
// This limitation is acquired from the maximum number of dimensions in OpenSearch.
static constexpr size_t MAX_VECTOR_DIMENSION = 16000;
vector_dimension_t _dimension;
virtual sstring to_string() const override {
return seastar::format("vector<{}, {}>", _type, _dimension);
}
public:
raw_vector(shared_ptr<raw> type, size_t dimension)
raw_vector(shared_ptr<raw> type, vector_dimension_t dimension)
: _type(std::move(type)), _dimension(dimension) {
}
@@ -417,7 +414,7 @@ cql3_type::raw::tuple(std::vector<shared_ptr<raw>> ts) {
}
shared_ptr<cql3_type::raw>
cql3_type::raw::vector(shared_ptr<raw> t, size_t dimension) {
cql3_type::raw::vector(shared_ptr<raw> t, vector_dimension_t dimension) {
return ::make_shared<raw_vector>(std::move(t), dimension);
}

View File

@@ -39,6 +39,9 @@ public:
data_type get_type() const { return _type; }
const sstring& to_string() const { return _type->cql3_type_name(); }
// This limitation is acquired from the maximum number of dimensions in OpenSearch.
static constexpr vector_dimension_t MAX_VECTOR_DIMENSION = 16000;
// For UserTypes, we need to know the current keyspace to resolve the
// actual type used, so Raw is a "not yet prepared" CQL3Type.
class raw {
@@ -64,7 +67,7 @@ public:
static shared_ptr<raw> list(shared_ptr<raw> t);
static shared_ptr<raw> set(shared_ptr<raw> t);
static shared_ptr<raw> tuple(std::vector<shared_ptr<raw>> ts);
static shared_ptr<raw> vector(shared_ptr<raw> t, size_t dimension);
static shared_ptr<raw> vector(shared_ptr<raw> t, vector_dimension_t dimension);
static shared_ptr<raw> frozen(shared_ptr<raw> t);
friend sstring format_as(const raw& r) {
return r.to_string();

View File

@@ -502,8 +502,8 @@ vector_validate_assignable_to(const collection_constructor& c, data_dictionary::
throw exceptions::invalid_request_exception(format("Invalid vector type literal for {} of type {}", *receiver.name, receiver.type->as_cql3_type()));
}
size_t expected_size = vt->get_dimension();
if (!expected_size) {
vector_dimension_t expected_size = vt->get_dimension();
if (expected_size == 0) {
throw exceptions::invalid_request_exception(format("Invalid vector type literal for {}: type {} expects at least one element",
*receiver.name, receiver.type->as_cql3_type()));
}

View File

@@ -18,7 +18,7 @@ namespace functions {
namespace detail {
std::vector<float> extract_float_vector(const bytes_opt& param, size_t dimension) {
std::vector<float> extract_float_vector(const bytes_opt& param, vector_dimension_t dimension) {
if (!param) {
throw exceptions::invalid_request_exception("Cannot extract float vector from null parameter");
}
@@ -156,7 +156,7 @@ std::vector<data_type> retrieve_vector_arg_types(const function_name& name, cons
}
}
size_t dimension = first_dim_opt ? *first_dim_opt : *second_dim_opt;
vector_dimension_t dimension = first_dim_opt ? *first_dim_opt : *second_dim_opt;
auto type = vector_type_impl::get_instance(float_type, dimension);
return {type, type};
}
@@ -170,7 +170,7 @@ bytes_opt vector_similarity_fct::execute(std::span<const bytes_opt> parameters)
// Extract dimension from the vector type
const auto& type = static_cast<const vector_type_impl&>(*arg_types()[0]);
size_t dimension = type.get_dimension();
vector_dimension_t dimension = type.get_dimension();
// Optimized path: extract floats directly from bytes, bypassing data_value overhead
std::vector<float> v1 = detail::extract_float_vector(parameters[0], dimension);

View File

@@ -39,7 +39,7 @@ namespace detail {
// Extract float vector directly from serialized bytes, bypassing data_value overhead.
// This is an internal API exposed for testing purposes.
// Vector<float, N> wire format: N floats as big-endian uint32_t values, 4 bytes each.
std::vector<float> extract_float_vector(const bytes_opt& param, size_t dimension);
std::vector<float> extract_float_vector(const bytes_opt& param, vector_dimension_t dimension);
} // namespace detail

View File

@@ -16,6 +16,7 @@
#include <string>
#include <tuple>
#include "cql3/cql3_type.hh"
#include "types/user.hh"
#include "types/map.hh"
#include "types/list.hh"
@@ -113,7 +114,7 @@ std::vector<data_type> type_parser::get_type_parameters(bool multicell)
throw parse_exception(_str, _idx, "unexpected end of string");
}
std::tuple<data_type, size_t> type_parser::get_vector_parameters()
std::tuple<data_type, vector_dimension_t> type_parser::get_vector_parameters()
{
if (is_eos() || _str[_idx] != '(') {
throw std::logic_error("internal error");
@@ -128,7 +129,7 @@ std::tuple<data_type, size_t> type_parser::get_vector_parameters()
}
data_type type = do_parse(true);
size_t size = 0;
vector_dimension_t size = 0;
if (_str[_idx] == ',') {
++_idx;
skip_blank();
@@ -142,7 +143,20 @@ std::tuple<data_type, size_t> type_parser::get_vector_parameters()
throw parse_exception(_str, _idx, "expected digit or ')'");
}
size = std::stoul(_str.substr(i, _idx - i));
unsigned long parsed_size;
try {
parsed_size = std::stoul(_str.substr(i, _idx - i));
} catch (const std::exception& e) {
throw parse_exception(_str, i, format("Invalid vector dimension: {}", e.what()));
}
static_assert(sizeof(unsigned long) >= sizeof(vector_dimension_t));
if (parsed_size == 0) {
throw parse_exception(_str, _idx, "Vectors must have a dimension greater than 0");
}
if (parsed_size > cql3::cql3_type::MAX_VECTOR_DIMENSION) {
throw parse_exception(_str, _idx, format("Vectors must have a dimension less than or equal to {}", cql3::cql3_type::MAX_VECTOR_DIMENSION));
}
size = static_cast<vector_dimension_t>(parsed_size);
++_idx; // skipping ')'
return std::make_tuple(type, size);

View File

@@ -97,7 +97,7 @@ public:
}
#endif
std::vector<data_type> get_type_parameters(bool multicell=true);
std::tuple<data_type, size_t> get_vector_parameters();
std::tuple<data_type, vector_dimension_t> get_vector_parameters();
std::tuple<sstring, bytes, std::vector<bytes>, std::vector<data_type>> get_user_type_parameters();
data_type do_parse(bool multicell = true);

View File

@@ -743,7 +743,7 @@ struct from_lua_visitor {
}
const data_type& elements_type = t.get_elements_type();
size_t num_elements = t.get_dimension();
vector_dimension_t num_elements = t.get_dimension();
using table_pair = std::pair<utils::multiprecision_int, data_value>;
std::vector<table_pair> pairs;

View File

@@ -4,7 +4,10 @@
#############################################################################
# Tests involving the "vector" column type.
from .util import new_test_table
import pytest
from .util import new_test_table, is_scylla
from cassandra.protocol import InvalidRequest, SyntaxException
def test_vector_of_set_using_arguments_binding(cql, test_keyspace):
@@ -26,3 +29,41 @@ def test_vector_of_set_using_arguments_binding(cql, test_keyspace):
assert row is not None
assert row.v == value_to_insert
# This is an artificial limit set to the value matching the OpenSearch implementation of Vector Search.
# Cassandra itself does not have a hard limit on the dimension of vectors, except mentioning 2^13 as a recommended maximum in the documentation.
# Instead, we test with Java's Integer.MAX_VALUE (2^31 - 1).
@pytest.fixture(scope="module")
def MAX_VECTOR_DIMENSION(cql):
return 16000 if is_scylla(cql) else 2**31 - 1
def test_vector_dimension_upper_bound_is_allowed(cql, test_keyspace, MAX_VECTOR_DIMENSION):
with new_test_table(cql, test_keyspace, f"pk int primary key, v vector<float, {MAX_VECTOR_DIMENSION}>"):
pass
def test_vector_dimension_above_upper_bound_is_rejected(cql, test_keyspace, MAX_VECTOR_DIMENSION):
with pytest.raises(InvalidRequest, match=f"Vectors must have a dimension less than or equal to {MAX_VECTOR_DIMENSION}") if is_scylla(cql) else pytest.raises(SyntaxException, match="NumberFormatException"):
with new_test_table(cql, test_keyspace, f"pk int primary key, v vector<float, {MAX_VECTOR_DIMENSION + 1}>"):
pass
def test_vector_dimension_zero_is_rejected(cql, test_keyspace):
with pytest.raises(InvalidRequest, match="Vectors must have a dimension greater than 0" if is_scylla(cql) else "vectors may only have positive dimensions"):
with new_test_table(cql, test_keyspace, "pk int primary key, v vector<float, 0>"):
pass
def test_vector_dimension_negative_is_rejected(cql, test_keyspace):
with pytest.raises(InvalidRequest, match="Vectors must have a dimension greater than 0" if is_scylla(cql) else "vectors may only have positive dimensions"):
with new_test_table(cql, test_keyspace, "pk int primary key, v vector<float, -18>"):
pass
@pytest.mark.parametrize("invalid_dimension", ["dog", "123x", "1.5"])
def test_vector_dimension_non_integer_is_rejected(cql, test_keyspace, invalid_dimension):
with pytest.raises(SyntaxException):
with new_test_table(cql, test_keyspace, f"pk int primary key, v vector<float, {invalid_dimension}>"):
pass

View File

@@ -392,7 +392,7 @@ tuple_constructor make_tuple_constructor(std::vector<expression> elements, std::
.type = tuple_type_impl::get_instance(std::move(element_types))};
}
collection_constructor make_vector_constructor(std::vector<expression> elements, data_type elements_type, size_t dimension) {
collection_constructor make_vector_constructor(std::vector<expression> elements, data_type elements_type, vector_dimension_t dimension) {
return collection_constructor{.style = collection_constructor::style_type::vector,
.elements = std::move(elements),
.type = vector_type_impl::get_instance(elements_type, dimension)};

View File

@@ -116,7 +116,7 @@ collection_constructor make_map_constructor(const std::vector<std::pair<expressi
data_type key_type,
data_type element_type);
tuple_constructor make_tuple_constructor(std::vector<expression> elements, std::vector<data_type> element_types);
collection_constructor make_vector_constructor(std::vector<expression> elements, data_type elements_type, size_t dimension);
collection_constructor make_vector_constructor(std::vector<expression> elements, data_type elements_type, vector_dimension_t dimension);
usertype_constructor make_usertype_constructor(std::vector<std::pair<std::string_view, constant>> field_values);
::lw_shared_ptr<column_specification> make_receiver(data_type receiver_type, sstring name = "receiver_name");

View File

@@ -1643,22 +1643,22 @@ static void validate_aux(const tuple_type_impl& t, View v) {
}
}
sstring vector_type_impl::make_name(data_type type, size_t dimension) {
sstring vector_type_impl::make_name(data_type type, vector_dimension_t dimension) {
// To keep format compatibility with Origin we never wrap
// vector name into
// "org.apache.cassandra.db.marshal.FrozenType(...)".
return seastar::format("org.apache.cassandra.db.marshal.VectorType({}, {})", type->name(), dimension);
}
vector_type_impl::vector_type_impl(data_type elements, size_t dimension)
vector_type_impl::vector_type_impl(data_type elements, vector_dimension_t dimension)
: concrete_type(kind::vector, make_name(elements, dimension),
elements->value_length_if_fixed() ? std::optional(elements->value_length_if_fixed().value()*dimension):std::nullopt),
elements->value_length_if_fixed() ? std::optional(elements->value_length_if_fixed().value() * dimension) : std::nullopt),
_elements_type(elements), _dimension(dimension) {
_contains_set_or_map = _elements_type->contains_set_or_map();
}
shared_ptr<const vector_type_impl>
vector_type_impl::get_instance(data_type elements, size_t dimension) {
vector_type_impl::get_instance(data_type elements, vector_dimension_t dimension) {
return intern::get_instance(elements, dimension);
}
@@ -1681,7 +1681,7 @@ static void serialize_vector(const vector_type_impl& type, const vector_type_imp
}
std::strong_ordering
vector_type_impl::compare_vectors(data_type elements, size_t dimension, managed_bytes_view o1, managed_bytes_view o2) {
vector_type_impl::compare_vectors(data_type elements, vector_dimension_t dimension, managed_bytes_view o1, managed_bytes_view o2) {
if (o1.empty()) {
return o2.empty() ? std::strong_ordering::equal : std::strong_ordering::less;
} else if (o2.empty()) {

View File

@@ -176,6 +176,8 @@ struct timeuuid_native_type {
using data_type = shared_ptr<const abstract_type>;
using vector_dimension_t = uint32_t;
template <typename T>
const T& value_cast(const data_value& value);

View File

@@ -14,20 +14,20 @@
#include "vint-serialization.hh"
class vector_type_impl : public concrete_type<std::vector<data_value>> {
using intern = type_interning_helper<vector_type_impl, data_type, size_t>;
using intern = type_interning_helper<vector_type_impl, data_type, vector_dimension_t>;
protected:
data_type _elements_type;
size_t _dimension;
vector_dimension_t _dimension;
public:
vector_type_impl(data_type elements_type, size_t dimension);
static shared_ptr<const vector_type_impl> get_instance(data_type type, size_t dimension);
vector_type_impl(data_type elements_type, vector_dimension_t dimension);
static shared_ptr<const vector_type_impl> get_instance(data_type type, vector_dimension_t dimension);
data_type get_elements_type() const {
return _elements_type;
}
size_t get_dimension() const {
vector_dimension_t get_dimension() const {
return _dimension;
}
static std::strong_ordering compare_vectors(data_type elements_comparator, size_t dimension,
static std::strong_ordering compare_vectors(data_type elements_comparator, vector_dimension_t dimension,
managed_bytes_view o1, managed_bytes_view o2);
std::vector<managed_bytes> split_fragmented(FragmentedView auto v) const {
@@ -94,7 +94,7 @@ public:
return ret;
}
private:
static sstring make_name(data_type type, size_t dimension);
static sstring make_name(data_type type, vector_dimension_t dimension);
};