mirror of
https://github.com/scylladb/scylladb.git
synced 2026-05-12 19:02:12 +00:00
cql: vector: fix vector dimension type
Switch vector dimension handling to the fixed-width `uint32_t` type, update parsing/validation, and add boundary tests. The dimension is first parsed as `unsigned long`, which is guaranteed to be **at least** 32 bits wide; once the parsed value has been checked against `MAX_VECTOR_DIMENSION`, it is safe to narrow it to `uint32_t`. Move `MAX_VECTOR_DIMENSION` from `cql3_type::raw_vector` to `cql3_type` to ensure public visibility for checks outside the class. Add tests to verify the type boundaries. Fixes: https://scylladb.atlassian.net/browse/SCYLLADB-223 Signed-off-by: Yaniv Kaul <yaniv.kaul@scylladb.com> Co-authored-by: Dawid Pawlik <dawid.pawlik@scylladb.com> Closes scylladb/scylladb#28762
This commit is contained in:
committed by
Nadav Har'El
parent
4a60ee28a2
commit
ead9961783
16
cql3/Cql.g
16
cql3/Cql.g
@@ -2071,7 +2071,21 @@ vector_type returns [shared_ptr<cql3::cql3_type::raw> pt]
|
||||
{
|
||||
if ($d.text[0] == '-')
|
||||
throw exceptions::invalid_request_exception("Vectors must have a dimension greater than 0");
|
||||
$pt = cql3::cql3_type::raw::vector(t, std::stoul($d.text));
|
||||
unsigned long parsed_dimension;
|
||||
try {
|
||||
parsed_dimension = std::stoul($d.text);
|
||||
} catch (const std::exception& e) {
|
||||
throw exceptions::invalid_request_exception(format("Invalid vector dimension: {}", $d.text));
|
||||
}
|
||||
static_assert(sizeof(unsigned long) >= sizeof(vector_dimension_t));
|
||||
if (parsed_dimension == 0) {
|
||||
throw exceptions::invalid_request_exception("Vectors must have a dimension greater than 0");
|
||||
}
|
||||
if (parsed_dimension > cql3::cql3_type::MAX_VECTOR_DIMENSION) {
|
||||
throw exceptions::invalid_request_exception(
|
||||
format("Vectors must have a dimension less than or equal to {}", cql3::cql3_type::MAX_VECTOR_DIMENSION));
|
||||
}
|
||||
$pt = cql3::cql3_type::raw::vector(t, static_cast<vector_dimension_t>(parsed_dimension));
|
||||
}
|
||||
;
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ public:
|
||||
|
||||
struct vector_test_result {
|
||||
test_result result;
|
||||
std::optional<size_t> dimension_opt;
|
||||
std::optional<vector_dimension_t> dimension_opt;
|
||||
};
|
||||
|
||||
static bool is_assignable(test_result tr) {
|
||||
|
||||
@@ -307,17 +307,14 @@ public:
|
||||
|
||||
class cql3_type::raw_vector : public raw {
|
||||
shared_ptr<raw> _type;
|
||||
size_t _dimension;
|
||||
|
||||
// This limitation is acquired from the maximum number of dimensions in OpenSearch.
|
||||
static constexpr size_t MAX_VECTOR_DIMENSION = 16000;
|
||||
vector_dimension_t _dimension;
|
||||
|
||||
virtual sstring to_string() const override {
|
||||
return seastar::format("vector<{}, {}>", _type, _dimension);
|
||||
}
|
||||
|
||||
public:
|
||||
raw_vector(shared_ptr<raw> type, size_t dimension)
|
||||
raw_vector(shared_ptr<raw> type, vector_dimension_t dimension)
|
||||
: _type(std::move(type)), _dimension(dimension) {
|
||||
}
|
||||
|
||||
@@ -417,7 +414,7 @@ cql3_type::raw::tuple(std::vector<shared_ptr<raw>> ts) {
|
||||
}
|
||||
|
||||
shared_ptr<cql3_type::raw>
|
||||
cql3_type::raw::vector(shared_ptr<raw> t, size_t dimension) {
|
||||
cql3_type::raw::vector(shared_ptr<raw> t, vector_dimension_t dimension) {
|
||||
return ::make_shared<raw_vector>(std::move(t), dimension);
|
||||
}
|
||||
|
||||
|
||||
@@ -39,6 +39,9 @@ public:
|
||||
data_type get_type() const { return _type; }
|
||||
const sstring& to_string() const { return _type->cql3_type_name(); }
|
||||
|
||||
// This limitation is acquired from the maximum number of dimensions in OpenSearch.
|
||||
static constexpr vector_dimension_t MAX_VECTOR_DIMENSION = 16000;
|
||||
|
||||
// For UserTypes, we need to know the current keyspace to resolve the
|
||||
// actual type used, so Raw is a "not yet prepared" CQL3Type.
|
||||
class raw {
|
||||
@@ -64,7 +67,7 @@ public:
|
||||
static shared_ptr<raw> list(shared_ptr<raw> t);
|
||||
static shared_ptr<raw> set(shared_ptr<raw> t);
|
||||
static shared_ptr<raw> tuple(std::vector<shared_ptr<raw>> ts);
|
||||
static shared_ptr<raw> vector(shared_ptr<raw> t, size_t dimension);
|
||||
static shared_ptr<raw> vector(shared_ptr<raw> t, vector_dimension_t dimension);
|
||||
static shared_ptr<raw> frozen(shared_ptr<raw> t);
|
||||
friend sstring format_as(const raw& r) {
|
||||
return r.to_string();
|
||||
|
||||
@@ -502,8 +502,8 @@ vector_validate_assignable_to(const collection_constructor& c, data_dictionary::
|
||||
throw exceptions::invalid_request_exception(format("Invalid vector type literal for {} of type {}", *receiver.name, receiver.type->as_cql3_type()));
|
||||
}
|
||||
|
||||
size_t expected_size = vt->get_dimension();
|
||||
if (!expected_size) {
|
||||
vector_dimension_t expected_size = vt->get_dimension();
|
||||
if (expected_size == 0) {
|
||||
throw exceptions::invalid_request_exception(format("Invalid vector type literal for {}: type {} expects at least one element",
|
||||
*receiver.name, receiver.type->as_cql3_type()));
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ namespace functions {
|
||||
|
||||
namespace detail {
|
||||
|
||||
std::vector<float> extract_float_vector(const bytes_opt& param, size_t dimension) {
|
||||
std::vector<float> extract_float_vector(const bytes_opt& param, vector_dimension_t dimension) {
|
||||
if (!param) {
|
||||
throw exceptions::invalid_request_exception("Cannot extract float vector from null parameter");
|
||||
}
|
||||
@@ -156,7 +156,7 @@ std::vector<data_type> retrieve_vector_arg_types(const function_name& name, cons
|
||||
}
|
||||
}
|
||||
|
||||
size_t dimension = first_dim_opt ? *first_dim_opt : *second_dim_opt;
|
||||
vector_dimension_t dimension = first_dim_opt ? *first_dim_opt : *second_dim_opt;
|
||||
auto type = vector_type_impl::get_instance(float_type, dimension);
|
||||
return {type, type};
|
||||
}
|
||||
@@ -170,7 +170,7 @@ bytes_opt vector_similarity_fct::execute(std::span<const bytes_opt> parameters)
|
||||
|
||||
// Extract dimension from the vector type
|
||||
const auto& type = static_cast<const vector_type_impl&>(*arg_types()[0]);
|
||||
size_t dimension = type.get_dimension();
|
||||
vector_dimension_t dimension = type.get_dimension();
|
||||
|
||||
// Optimized path: extract floats directly from bytes, bypassing data_value overhead
|
||||
std::vector<float> v1 = detail::extract_float_vector(parameters[0], dimension);
|
||||
|
||||
@@ -39,7 +39,7 @@ namespace detail {
|
||||
// Extract float vector directly from serialized bytes, bypassing data_value overhead.
|
||||
// This is an internal API exposed for testing purposes.
|
||||
// Vector<float, N> wire format: N floats as big-endian uint32_t values, 4 bytes each.
|
||||
std::vector<float> extract_float_vector(const bytes_opt& param, size_t dimension);
|
||||
std::vector<float> extract_float_vector(const bytes_opt& param, vector_dimension_t dimension);
|
||||
|
||||
} // namespace detail
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "cql3/cql3_type.hh"
|
||||
#include "types/user.hh"
|
||||
#include "types/map.hh"
|
||||
#include "types/list.hh"
|
||||
@@ -113,7 +114,7 @@ std::vector<data_type> type_parser::get_type_parameters(bool multicell)
|
||||
throw parse_exception(_str, _idx, "unexpected end of string");
|
||||
}
|
||||
|
||||
std::tuple<data_type, size_t> type_parser::get_vector_parameters()
|
||||
std::tuple<data_type, vector_dimension_t> type_parser::get_vector_parameters()
|
||||
{
|
||||
if (is_eos() || _str[_idx] != '(') {
|
||||
throw std::logic_error("internal error");
|
||||
@@ -128,7 +129,7 @@ std::tuple<data_type, size_t> type_parser::get_vector_parameters()
|
||||
}
|
||||
|
||||
data_type type = do_parse(true);
|
||||
size_t size = 0;
|
||||
vector_dimension_t size = 0;
|
||||
if (_str[_idx] == ',') {
|
||||
++_idx;
|
||||
skip_blank();
|
||||
@@ -142,7 +143,20 @@ std::tuple<data_type, size_t> type_parser::get_vector_parameters()
|
||||
throw parse_exception(_str, _idx, "expected digit or ')'");
|
||||
}
|
||||
|
||||
size = std::stoul(_str.substr(i, _idx - i));
|
||||
unsigned long parsed_size;
|
||||
try {
|
||||
parsed_size = std::stoul(_str.substr(i, _idx - i));
|
||||
} catch (const std::exception& e) {
|
||||
throw parse_exception(_str, i, format("Invalid vector dimension: {}", e.what()));
|
||||
}
|
||||
static_assert(sizeof(unsigned long) >= sizeof(vector_dimension_t));
|
||||
if (parsed_size == 0) {
|
||||
throw parse_exception(_str, _idx, "Vectors must have a dimension greater than 0");
|
||||
}
|
||||
if (parsed_size > cql3::cql3_type::MAX_VECTOR_DIMENSION) {
|
||||
throw parse_exception(_str, _idx, format("Vectors must have a dimension less than or equal to {}", cql3::cql3_type::MAX_VECTOR_DIMENSION));
|
||||
}
|
||||
size = static_cast<vector_dimension_t>(parsed_size);
|
||||
|
||||
++_idx; // skipping ')'
|
||||
return std::make_tuple(type, size);
|
||||
|
||||
@@ -97,7 +97,7 @@ public:
|
||||
}
|
||||
#endif
|
||||
std::vector<data_type> get_type_parameters(bool multicell=true);
|
||||
std::tuple<data_type, size_t> get_vector_parameters();
|
||||
std::tuple<data_type, vector_dimension_t> get_vector_parameters();
|
||||
std::tuple<sstring, bytes, std::vector<bytes>, std::vector<data_type>> get_user_type_parameters();
|
||||
data_type do_parse(bool multicell = true);
|
||||
|
||||
|
||||
@@ -743,7 +743,7 @@ struct from_lua_visitor {
|
||||
}
|
||||
|
||||
const data_type& elements_type = t.get_elements_type();
|
||||
size_t num_elements = t.get_dimension();
|
||||
vector_dimension_t num_elements = t.get_dimension();
|
||||
|
||||
using table_pair = std::pair<utils::multiprecision_int, data_value>;
|
||||
std::vector<table_pair> pairs;
|
||||
|
||||
@@ -4,7 +4,10 @@
|
||||
|
||||
#############################################################################
|
||||
# Tests involving the "vector" column type.
|
||||
from .util import new_test_table
|
||||
|
||||
import pytest
|
||||
from .util import new_test_table, is_scylla
|
||||
from cassandra.protocol import InvalidRequest, SyntaxException
|
||||
|
||||
|
||||
def test_vector_of_set_using_arguments_binding(cql, test_keyspace):
|
||||
@@ -26,3 +29,41 @@ def test_vector_of_set_using_arguments_binding(cql, test_keyspace):
|
||||
|
||||
assert row is not None
|
||||
assert row.v == value_to_insert
|
||||
|
||||
|
||||
# Scylla's maximum vector dimension (16000) is an artificial limit, set to the value matching the OpenSearch implementation of Vector Search.
|
||||
# Cassandra itself does not have a hard limit on the dimension of vectors, except mentioning 2^13 as a recommended maximum in the documentation.
|
||||
# When running against Cassandra, we therefore test with Java's Integer.MAX_VALUE (2^31 - 1) instead.
|
||||
@pytest.fixture(scope="module")
|
||||
def MAX_VECTOR_DIMENSION(cql):
|
||||
return 16000 if is_scylla(cql) else 2**31 - 1
|
||||
|
||||
|
||||
def test_vector_dimension_upper_bound_is_allowed(cql, test_keyspace, MAX_VECTOR_DIMENSION):
|
||||
with new_test_table(cql, test_keyspace, f"pk int primary key, v vector<float, {MAX_VECTOR_DIMENSION}>"):
|
||||
pass
|
||||
|
||||
|
||||
def test_vector_dimension_above_upper_bound_is_rejected(cql, test_keyspace, MAX_VECTOR_DIMENSION):
|
||||
with pytest.raises(InvalidRequest, match=f"Vectors must have a dimension less than or equal to {MAX_VECTOR_DIMENSION}") if is_scylla(cql) else pytest.raises(SyntaxException, match="NumberFormatException"):
|
||||
with new_test_table(cql, test_keyspace, f"pk int primary key, v vector<float, {MAX_VECTOR_DIMENSION + 1}>"):
|
||||
pass
|
||||
|
||||
|
||||
def test_vector_dimension_zero_is_rejected(cql, test_keyspace):
|
||||
with pytest.raises(InvalidRequest, match="Vectors must have a dimension greater than 0" if is_scylla(cql) else "vectors may only have positive dimensions"):
|
||||
with new_test_table(cql, test_keyspace, "pk int primary key, v vector<float, 0>"):
|
||||
pass
|
||||
|
||||
|
||||
def test_vector_dimension_negative_is_rejected(cql, test_keyspace):
|
||||
with pytest.raises(InvalidRequest, match="Vectors must have a dimension greater than 0" if is_scylla(cql) else "vectors may only have positive dimensions"):
|
||||
with new_test_table(cql, test_keyspace, "pk int primary key, v vector<float, -18>"):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.parametrize("invalid_dimension", ["dog", "123x", "1.5"])
|
||||
def test_vector_dimension_non_integer_is_rejected(cql, test_keyspace, invalid_dimension):
|
||||
with pytest.raises(SyntaxException):
|
||||
with new_test_table(cql, test_keyspace, f"pk int primary key, v vector<float, {invalid_dimension}>"):
|
||||
pass
|
||||
|
||||
@@ -392,7 +392,7 @@ tuple_constructor make_tuple_constructor(std::vector<expression> elements, std::
|
||||
.type = tuple_type_impl::get_instance(std::move(element_types))};
|
||||
}
|
||||
|
||||
collection_constructor make_vector_constructor(std::vector<expression> elements, data_type elements_type, size_t dimension) {
|
||||
collection_constructor make_vector_constructor(std::vector<expression> elements, data_type elements_type, vector_dimension_t dimension) {
|
||||
return collection_constructor{.style = collection_constructor::style_type::vector,
|
||||
.elements = std::move(elements),
|
||||
.type = vector_type_impl::get_instance(elements_type, dimension)};
|
||||
|
||||
@@ -116,7 +116,7 @@ collection_constructor make_map_constructor(const std::vector<std::pair<expressi
|
||||
data_type key_type,
|
||||
data_type element_type);
|
||||
tuple_constructor make_tuple_constructor(std::vector<expression> elements, std::vector<data_type> element_types);
|
||||
collection_constructor make_vector_constructor(std::vector<expression> elements, data_type elements_type, size_t dimension);
|
||||
collection_constructor make_vector_constructor(std::vector<expression> elements, data_type elements_type, vector_dimension_t dimension);
|
||||
usertype_constructor make_usertype_constructor(std::vector<std::pair<std::string_view, constant>> field_values);
|
||||
|
||||
::lw_shared_ptr<column_specification> make_receiver(data_type receiver_type, sstring name = "receiver_name");
|
||||
|
||||
@@ -1643,22 +1643,22 @@ static void validate_aux(const tuple_type_impl& t, View v) {
|
||||
}
|
||||
}
|
||||
|
||||
sstring vector_type_impl::make_name(data_type type, size_t dimension) {
|
||||
sstring vector_type_impl::make_name(data_type type, vector_dimension_t dimension) {
|
||||
// To keep format compatibility with Origin we never wrap
|
||||
// vector name into
|
||||
// "org.apache.cassandra.db.marshal.FrozenType(...)".
|
||||
return seastar::format("org.apache.cassandra.db.marshal.VectorType({}, {})", type->name(), dimension);
|
||||
}
|
||||
|
||||
vector_type_impl::vector_type_impl(data_type elements, size_t dimension)
|
||||
vector_type_impl::vector_type_impl(data_type elements, vector_dimension_t dimension)
|
||||
: concrete_type(kind::vector, make_name(elements, dimension),
|
||||
elements->value_length_if_fixed() ? std::optional(elements->value_length_if_fixed().value()*dimension):std::nullopt),
|
||||
elements->value_length_if_fixed() ? std::optional(elements->value_length_if_fixed().value() * dimension) : std::nullopt),
|
||||
_elements_type(elements), _dimension(dimension) {
|
||||
_contains_set_or_map = _elements_type->contains_set_or_map();
|
||||
}
|
||||
|
||||
shared_ptr<const vector_type_impl>
|
||||
vector_type_impl::get_instance(data_type elements, size_t dimension) {
|
||||
vector_type_impl::get_instance(data_type elements, vector_dimension_t dimension) {
|
||||
return intern::get_instance(elements, dimension);
|
||||
}
|
||||
|
||||
@@ -1681,7 +1681,7 @@ static void serialize_vector(const vector_type_impl& type, const vector_type_imp
|
||||
}
|
||||
|
||||
std::strong_ordering
|
||||
vector_type_impl::compare_vectors(data_type elements, size_t dimension, managed_bytes_view o1, managed_bytes_view o2) {
|
||||
vector_type_impl::compare_vectors(data_type elements, vector_dimension_t dimension, managed_bytes_view o1, managed_bytes_view o2) {
|
||||
if (o1.empty()) {
|
||||
return o2.empty() ? std::strong_ordering::equal : std::strong_ordering::less;
|
||||
} else if (o2.empty()) {
|
||||
|
||||
@@ -176,6 +176,8 @@ struct timeuuid_native_type {
|
||||
|
||||
using data_type = shared_ptr<const abstract_type>;
|
||||
|
||||
using vector_dimension_t = uint32_t;
|
||||
|
||||
template <typename T>
|
||||
const T& value_cast(const data_value& value);
|
||||
|
||||
|
||||
@@ -14,20 +14,20 @@
|
||||
#include "vint-serialization.hh"
|
||||
|
||||
class vector_type_impl : public concrete_type<std::vector<data_value>> {
|
||||
using intern = type_interning_helper<vector_type_impl, data_type, size_t>;
|
||||
using intern = type_interning_helper<vector_type_impl, data_type, vector_dimension_t>;
|
||||
protected:
|
||||
data_type _elements_type;
|
||||
size_t _dimension;
|
||||
vector_dimension_t _dimension;
|
||||
public:
|
||||
vector_type_impl(data_type elements_type, size_t dimension);
|
||||
static shared_ptr<const vector_type_impl> get_instance(data_type type, size_t dimension);
|
||||
vector_type_impl(data_type elements_type, vector_dimension_t dimension);
|
||||
static shared_ptr<const vector_type_impl> get_instance(data_type type, vector_dimension_t dimension);
|
||||
data_type get_elements_type() const {
|
||||
return _elements_type;
|
||||
}
|
||||
size_t get_dimension() const {
|
||||
vector_dimension_t get_dimension() const {
|
||||
return _dimension;
|
||||
}
|
||||
static std::strong_ordering compare_vectors(data_type elements_comparator, size_t dimension,
|
||||
static std::strong_ordering compare_vectors(data_type elements_comparator, vector_dimension_t dimension,
|
||||
managed_bytes_view o1, managed_bytes_view o2);
|
||||
|
||||
std::vector<managed_bytes> split_fragmented(FragmentedView auto v) const {
|
||||
@@ -94,7 +94,7 @@ public:
|
||||
return ret;
|
||||
}
|
||||
private:
|
||||
static sstring make_name(data_type type, size_t dimension);
|
||||
static sstring make_name(data_type type, vector_dimension_t dimension);
|
||||
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user