diff --git a/cql3/expr/expression.cc b/cql3/expr/expression.cc index 852a3c48c9..ee1651b6b3 100644 --- a/cql3/expr/expression.cc +++ b/cql3/expr/expression.cc @@ -75,6 +75,7 @@ static cql3::raw_value do_evaluate(const tuple_constructor&, const evaluation_in static cql3::raw_value do_evaluate(const collection_constructor&, const evaluation_inputs&); static cql3::raw_value do_evaluate(const usertype_constructor&, const evaluation_inputs&); static cql3::raw_value do_evaluate(const function_call&, const evaluation_inputs&); +static cql3::raw_value do_evaluate(const unary_operator&, const evaluation_inputs&); namespace { @@ -839,6 +840,9 @@ auto fmt::formatter::format(const cql3::expr::e }, [&] (const temporary& t) { out = fmt::format_to(out, "@temporary{}", t.index); + }, + [&] (const unary_operator& uo) { + out = fmt::format_to(out, "({}{})", uo.op, to_printer(uo.operand)); } }, pr.expr_to_print); return out; @@ -972,6 +976,9 @@ bool recurse_until(const expression& e, const noncopyable_function expression { + return unary_operator{uo.op, recurse(uo.operand)}; + }, [&] (LeafExpression auto const& e) -> expression { return e; }, @@ -1454,6 +1464,51 @@ do_evaluate(const temporary& t, const evaluation_inputs& inputs) { return inputs.temporaries[t.index]; } +static +cql3::raw_value +do_evaluate(const unary_operator& uo, const evaluation_inputs& inputs) { + // For now, this is do-nothing switch() supporting only the NEG operator. + // It will ask the compiler to warn us if we ever add a new type of unary + // operator and forget to update this function to handle it. + switch (uo.op) { + case unary_oper_t::NEG: + break; + } + raw_value operand_val = evaluate(uo.operand, inputs); + if (operand_val.is_null()) { + return raw_value::make_null(); + } + const abstract_type& t = type_of(uo.operand)->without_reversed(); + bytes result = operand_val.view().with_linearized([&](bytes_view bv) -> bytes { + return visit(t, make_visitor( + [&] (const integer_type_impl& itype) -> bytes { + T v = value_cast(itype.deserialize(bv)); + T res; + if (__builtin_sub_overflow(T(0), v, &res)) { + throw exceptions::invalid_request_exception("Arithmetic negation overflow"); + } + return serialized(res); + }, + [&] (const floating_type_impl& ftype) -> bytes { + return serialized(-value_cast(ftype.deserialize(bv))); + }, + [&] (const varint_type_impl& vtype) -> bytes { + utils::multiprecision_int v = value_cast(vtype.deserialize(bv)); + return serialized(-v); + }, + [&] (const decimal_type_impl& dtype) -> bytes { + big_decimal v = value_cast(dtype.deserialize(bv)); + return serialized(-v); + }, + [&] (const abstract_type& atype) -> bytes { + throw exceptions::invalid_request_exception( + format("Arithmetic negation is not supported for type {}", atype.cql3_type_name())); + } + )); + }); + return raw_value::make_value(managed_bytes(result)); +} + cql3::raw_value evaluate(const expression& e, const evaluation_inputs& inputs) { return expr::visit([&] (const ExpressionElement auto& ee) -> cql3::raw_value { return do_evaluate(ee, inputs); @@ -2031,6 +2086,9 @@ void fill_prepare_context(expression& e, prepare_context& ctx) { [](untyped_constant&) {}, [](constant&) {}, [](temporary&) {}, + [&](unary_operator& uo) { + fill_prepare_context(uo.operand, ctx); + }, }, e); } @@ -2110,6 +2168,9 @@ type_of(const expression& e) { }, }, e.type); }, + [] (const unary_operator& uo) { + return type_of(uo.operand); + }, [] (const ExpressionElement auto& e) -> data_type { return e.type; } @@ -2407,6 +2468,9 @@ aggregation_depth(const cql3::expr::expression& e) { }, [] (const usertype_constructor& uc) { return max_over_range(uc.elements | std::views::values); + }, + [] (const unary_operator& uo) { + return aggregation_depth(uo.operand); } }, e); } @@ -2496,6 +2560,10 @@ levellize_aggregation_depth(const cql3::expr::expression& e, unsigned desired_de [&] (usertype_constructor uc) -> expression { recurse_over_range(uc.elements | std::views::values); return uc; + }, + [&] (unary_operator uo) -> expression { + recurse(uo.operand); + return uo; } }, e); } @@ -2709,3 +2777,13 @@ std::string_view fmt::formatter::to_string(const cql3::expr: } on_internal_error(cql3::expr::expr_logger, fmt::format("unexpected oper_t value {}", static_cast(op))); } + +std::string_view fmt::formatter::to_string(const cql3::expr::unary_oper_t& op) { + using cql3::expr::unary_oper_t; + + switch (op) { + case unary_oper_t::NEG: + return "-"; + } + on_internal_error(cql3::expr::expr_logger, fmt::format("unexpected unary_oper_t value {}", static_cast(op))); +} diff --git a/cql3/expr/expression.hh b/cql3/expr/expression.hh index 2f68ba217a..776bc0088c 100644 --- a/cql3/expr/expression.hh +++ b/cql3/expr/expression.hh @@ -57,6 +57,7 @@ struct allow_local_index_tag {}; using allow_local_index = bool_class; struct binary_operator; +struct unary_operator; struct conjunction; struct column_value; struct subscript; @@ -77,6 +78,7 @@ template concept ExpressionElement = std::same_as || std::same_as + || std::same_as || std::same_as || std::same_as || std::same_as @@ -97,6 +99,7 @@ template concept invocable_on_expression = std::invocable && std::invocable + && std::invocable && std::invocable && std::invocable && std::invocable @@ -117,6 +120,7 @@ template concept invocable_on_expression_ref = std::invocable && std::invocable + && std::invocable && std::invocable && std::invocable && std::invocable @@ -223,6 +227,9 @@ const column_value& get_subscripted_column(const expression&); enum class oper_t { EQ, NEQ, LT, LTE, GTE, GT, IN, NOT_IN, CONTAINS, CONTAINS_KEY, IS_NOT, LIKE, ADD, SUB }; +/// The operator of a unary expression. +enum class unary_oper_t { NEG }; + /// Describes the nature of clustering-key comparisons. Useful for implementing SCYLLA_CLUSTERING_BOUND. enum class comparison_order : char { cql, ///< CQL order. (a,b)>(1,1) is equivalent to a>1 OR (a=1 AND b>1). @@ -248,6 +255,15 @@ struct binary_operator { friend bool operator==(const binary_operator&, const binary_operator&) = default; }; +// A unary operation on an expression. +// Currently only negation (unary_oper_t::NEG) for numeric types is supported. +struct unary_operator { + unary_oper_t op; + expression operand; + + friend bool operator==(const unary_operator&, const unary_operator&) = default; +}; + // A conjunction of expressions separated by the AND keyword. // For example: "a < 3 AND col1 = ? AND pk IN (1, 2)" struct conjunction { @@ -469,7 +485,8 @@ struct expression::impl final { conjunction, binary_operator, column_value, unresolved_identifier, column_mutation_attribute, function_call, cast, field_selection, bind_variable, untyped_constant, constant, tuple_constructor, - collection_constructor, usertype_constructor, subscript, temporary>; + collection_constructor, usertype_constructor, subscript, temporary, + unary_operator>; variant_type v; impl(variant_type v) : v(std::move(v)) {} }; @@ -603,3 +620,16 @@ struct fmt::formatter { private: static std::string_view to_string(const cql3::expr::oper_t& op); }; + +template <> +struct fmt::formatter { + constexpr auto parse(format_parse_context& ctx) { return ctx.begin(); } + + template + auto format(const cql3::expr::unary_oper_t& op, FormatContext& ctx) const { + return fmt::format_to(ctx.out(), "{}", to_string(op)); + } + +private: + static std::string_view to_string(const cql3::expr::unary_oper_t& op); +}; diff --git a/cql3/expr/prepare_expr.cc b/cql3/expr/prepare_expr.cc index 6fb57d39c3..5c64aceef4 100644 --- a/cql3/expr/prepare_expr.cc +++ b/cql3/expr/prepare_expr.cc @@ -1344,6 +1344,19 @@ try_prepare_expression(const expression& expr, data_dictionary::database db, con } return result; }, + [&] (const unary_operator& uo) -> std::optional { + // Prepare the operand; unary_operator preserves its operand's type. + auto prepared_operand = try_prepare_expression(uo.operand, db, keyspace, schema_opt, receiver); + if (!prepared_operand) { + return std::nullopt; + } + unary_operator result{uo.op, std::move(*prepared_operand)}; + // Constant folding: if the operand is fully known, evaluate now. + if (is(result.operand)) { + return constant(evaluate(result, query_options::DEFAULT), type_of(result.operand)); + } + return result; + }, [&] (const conjunction& conj) -> std::optional { return prepare_conjunction(conj, db, keyspace, schema_opt, receiver); }, @@ -1455,6 +1468,9 @@ test_assignment(const expression& expr, data_dictionary::database db, const sstr [&] (const binary_operator&) -> test_result { on_internal_error(expr_logger, "binary_operators are not yet reachable via test_assignment()"); }, + [&] (const unary_operator&) -> test_result { + on_internal_error(expr_logger, "unary_operators are not yet reachable via test_assignment()"); + }, [&] (const conjunction&) -> test_result { on_internal_error(expr_logger, "conjunctions are not yet reachable via test_assignment()"); }, @@ -1550,6 +1566,9 @@ test_assignment_any_size_float_vector(const expression& expr) { [&] (const binary_operator&) -> test_result { return NOT_ASSIGNABLE; }, + [&] (const unary_operator&) -> test_result { + return NOT_ASSIGNABLE; + }, [&] (const conjunction&) -> test_result { return NOT_ASSIGNABLE; }, diff --git a/cql3/restrictions/statement_restrictions.cc b/cql3/restrictions/statement_restrictions.cc index b11d6a893a..eda7f685dd 100644 --- a/cql3/restrictions/statement_restrictions.cc +++ b/cql3/restrictions/statement_restrictions.cc @@ -521,6 +521,9 @@ to_predicates( [&] (const binary_operator&) -> std::vector { return cannot_solve(oper); }, + [&] (const unary_operator&) -> std::vector { + return cannot_solve(oper); + }, [&] (const conjunction&) -> std::vector { return cannot_solve(oper); }, @@ -556,6 +559,9 @@ to_predicates( }, }, oper.lhs); }, + [] (const unary_operator& uo) -> std::vector { + return cannot_solve(uo); + }, [] (const column_value& cv) -> std::vector { return cannot_solve(cv); }, diff --git a/cql3/selection/selectable.cc b/cql3/selection/selectable.cc index 9239761e6e..3aaf58bc88 100644 --- a/cql3/selection/selectable.cc +++ b/cql3/selection/selectable.cc @@ -40,6 +40,9 @@ selectable_processes_selection(const expr::expression& selectable) { [&] (const expr::binary_operator& conj) -> bool { on_internal_error(slogger, "no way to express 'SELECT a binop b' in the grammar yet"); }, + [&] (const expr::unary_operator&) -> bool { + on_internal_error(slogger, "no way to express 'SELECT unop a' in the grammar yet"); + }, [] (const expr::subscript&) -> bool { return true; }, diff --git a/test/boost/expr_test.cc b/test/boost/expr_test.cc index b7cc8d3700..d87269ee6e 100644 --- a/test/boost/expr_test.cc +++ b/test/boost/expr_test.cc @@ -5302,3 +5302,124 @@ BOOST_AUTO_TEST_CASE(evaluate_sub_reversed_type) { BOOST_REQUIRE_EQUAL(raw_to(result, int32_type), 7); } +// Helper: evaluate a NEG unary_operator with a constant operand. +static raw_value eval_neg(expression expr) { + return evaluate(unary_operator{unary_oper_t::NEG, std::move(expr)}, evaluation_inputs{}); +} + +BOOST_AUTO_TEST_CASE(evaluate_neg_fixed_width_integers) { + // int (32-bit) + BOOST_REQUIRE_EQUAL(raw_to(eval_neg(make_int_const(5)), int32_type), -5); + BOOST_REQUIRE_EQUAL(raw_to(eval_neg(make_int_const(-7)), int32_type), 7); + + // tinyint (8-bit) + BOOST_REQUIRE_EQUAL(raw_to(eval_neg(make_tinyint_const(10)), byte_type), int8_t(-10)); + + // smallint (16-bit) + BOOST_REQUIRE_EQUAL(raw_to(eval_neg(make_smallint_const(300)), short_type), int16_t(-300)); + + // bigint (64-bit) + BOOST_REQUIRE_EQUAL(raw_to(eval_neg(make_bigint_const(1'000'000'000LL)), long_type), int64_t(-1'000'000'000LL)); + + // type_of a NEG expression is the type of the operand + expression neg_expr = unary_operator{unary_oper_t::NEG, make_int_const(1)}; + BOOST_REQUIRE(type_of(neg_expr) == int32_type); +} + +BOOST_AUTO_TEST_CASE(evaluate_neg_floating_point) { + // float (32-bit) + BOOST_REQUIRE_EQUAL(raw_to(eval_neg(make_float_const(3.5f)), float_type), -3.5f); + + // double (64-bit) + BOOST_REQUIRE_EQUAL(raw_to(eval_neg(make_double_const(-2.5)), double_type), 2.5); +} + +BOOST_AUTO_TEST_CASE(evaluate_neg_varint) { + auto make_varint = [](int64_t v) -> constant { + return constant(raw_value::make_value(managed_bytes(varint_type->decompose(utils::multiprecision_int(v)))), varint_type); + }; + BOOST_REQUIRE_EQUAL( + raw_to(eval_neg(make_varint(42)), varint_type), + utils::multiprecision_int(-42)); +} + +BOOST_AUTO_TEST_CASE(evaluate_neg_decimal) { + auto make_decimal = [](const char* s) -> constant { + return constant(raw_value::make_value(managed_bytes(decimal_type->decompose(big_decimal(s)))), decimal_type); + }; + BOOST_REQUIRE_EQUAL( + raw_to(eval_neg(make_decimal("3.14")), decimal_type), + big_decimal("-3.14")); +} + +BOOST_AUTO_TEST_CASE(evaluate_neg_null_propagation) { + // NEG(null) → null for any numeric type + BOOST_REQUIRE(eval_neg(constant::make_null(int32_type)).is_null()); + BOOST_REQUIRE(eval_neg(constant::make_null(long_type)).is_null()); + BOOST_REQUIRE(eval_neg(constant::make_null(float_type)).is_null()); +} + +BOOST_AUTO_TEST_CASE(evaluate_neg_overflow) { + // Negating the minimum value of each signed integer type overflows + BOOST_REQUIRE_THROW( + eval_neg(make_int_const(std::numeric_limits::min())), + exceptions::invalid_request_exception); + BOOST_REQUIRE_THROW( + eval_neg(make_tinyint_const(std::numeric_limits::min())), + exceptions::invalid_request_exception); + BOOST_REQUIRE_THROW( + eval_neg(make_smallint_const(std::numeric_limits::min())), + exceptions::invalid_request_exception); + BOOST_REQUIRE_THROW( + eval_neg(make_bigint_const(std::numeric_limits::min())), + exceptions::invalid_request_exception); +} + +// varint supports arbitrary precision, so negating INT64_MIN doesn't overflow. +BOOST_AUTO_TEST_CASE(evaluate_neg_arbitrary_precision_no_overflow) { + utils::multiprecision_int min_val(std::numeric_limits::min()); + constant c(raw_value::make_value(managed_bytes(varint_type->decompose(min_val))), varint_type); + BOOST_REQUIRE_EQUAL( + raw_to(eval_neg(c), varint_type), + -min_val); +} + +BOOST_AUTO_TEST_CASE(evaluate_neg_non_numeric_type) { + BOOST_REQUIRE_THROW( + eval_neg(make_text_const("hello")), + exceptions::invalid_request_exception); + + constant bool_const(raw_value::make_value(managed_bytes(boolean_type->decompose(true))), boolean_type); + BOOST_REQUIRE_THROW( + eval_neg(bool_const), + exceptions::invalid_request_exception); +} + +// NEG unary_operator should print as "(-operand)" +BOOST_AUTO_TEST_CASE(evaluate_neg_printer) { + expression neg_expr = unary_operator{unary_oper_t::NEG, make_int_const(3)}; + BOOST_REQUIRE_EQUAL(expr_print(neg_expr), "(-3)"); +} + +// Arithmetic NEG on a column with a reversed type (DESC clustering key) +// must work correctly, and not think reverse(int) is an unsupported type. +BOOST_AUTO_TEST_CASE(evaluate_neg_reversed_type) { + // Schema with a DESC clustering key: equivalent to + // CREATE TABLE test_ks.test_cf (pk int, ck int, PRIMARY KEY (pk, ck)) WITH CLUSTERING ORDER BY (ck DESC); + schema_ptr test_schema = + schema_builder("test_ks", "test_cf") + .with_column("pk", int32_type, column_kind::partition_key) + .with_column("ck", reversed_type_impl::get_instance(int32_type), column_kind::clustering_key) + .build(); + + auto [inputs, inputs_data] = make_evaluation_inputs(test_schema, { + {"pk", make_int_raw(1)}, + {"ck", make_int_raw(5)}, + }); + + // -ck should equal -5 + expression ck_col = column_value(test_schema->get_column_definition("ck")); + expression neg_expr = unary_operator{unary_oper_t::NEG, ck_col}; + raw_value result = evaluate(neg_expr, inputs); + BOOST_REQUIRE_EQUAL(raw_to(result, int32_type), -5); +}