diff --git a/cql3/restrictions/statement_restrictions.cc b/cql3/restrictions/statement_restrictions.cc index d0027cd37e..de82680651 100644 --- a/cql3/restrictions/statement_restrictions.cc +++ b/cql3/restrictions/statement_restrictions.cc @@ -232,6 +232,7 @@ data_type type(const predicate& p) { return std::visit( overloaded_functor{ + [] (const on_row&) { return boolean_type; }, // Not true, but the type won't be used. [] (const on_column& oc) { return oc.column->type->without_reversed().shared_from_this(); }, [] (const on_partition_key_token&) { return long_type; }, [] (const on_clustering_key_prefix&) -> data_type { on_internal_error(rlogger, "type: asked for clustering key prefix type"); }, @@ -281,9 +282,10 @@ require_on_single_column(const predicate& p) { on_internal_error(rlogger, "require_on_single_column: predicate is not on a single column"); } -/// Given an expression and a column definition , builds a function that returns -/// the set of all column values that would satisfy the expression. The _token_values variant finds -/// matching values for the partition token function call instead of the column. +/// Given an expression, decompose it into a set of predicates, on individual columns, +/// the table's tokens, or multiple columns. A predicate may know how to solve for +/// the set of all column values that would satisfy the expression, treated as a boolean +/// predicate on the column. If it does, the .solve_for member is set. /// /// An expression restricts possible values of a column or token: /// - `A>5` restricts A from below @@ -300,44 +302,62 @@ require_on_single_column(const predicate& p) { // The schema is needed to find out whether a call to token() function represents // the partition token. static -solve_for_t -possible_lhs_values(const column_definition* cdef, - const expression& expr, - const schema* table_schema_opt) { - const auto type = cdef ? 
&cdef->type->without_reversed() : long_type.get(); +std::vector +to_predicates( + const expression& expr, + const schema* table_schema_opt) { + static auto to_vector = [] (predicate p) -> std::vector { + return {std::move(p)}; + }; + static auto cannot_solve = [] (const expression& e) -> std::vector { + return to_vector(predicate{ + .solve_for = nullptr, + .filter = e, + .on = on_row{}, + }); + }; + static auto cannot_solve_on_column = [] (const expression& e, const column_definition* cdef) -> std::vector { + return to_vector(predicate{ + .solve_for = nullptr, + .filter = e, + .on = on_column{cdef}, + }); + }; return expr::visit(overloaded_functor{ - [] (const constant& constant_val) -> solve_for_t { + [] (const constant& constant_val) -> std::vector { std::optional bool_val = get_bool_value(constant_val); if (bool_val.has_value()) { - return *bool_val + auto solve = *bool_val ? solve_for_t([] (const query_options&) { return unbounded_value_set; }) : solve_for_t([] (const query_options&) { return empty_value_set; }); + return to_vector(predicate{ + .solve_for = std::move(solve), + .filter = constant_val, + .on = on_row{}, + }); } - return nullptr; - }, - [&] (const conjunction& conj) -> solve_for_t { - auto children = - conj.children - | std::views::transform([&] (const expression& e) { - return possible_lhs_values(cdef, e, table_schema_opt); - }) - | std::ranges::to(); - return [children, type] (const query_options& options) -> value_set { - return std::ranges::fold_left(children, unbounded_value_set, [&](value_set&& acc, const solve_for_t& child) { - return intersection( - std::move(acc), child(options), type); + return to_vector(predicate{ + .solve_for = [] (const query_options&) { return unbounded_value_set; }, + .filter = constant_val, + .on = on_row{}, }); - }; }, - [&] (const binary_operator& oper) -> solve_for_t { + [&] (const conjunction& conj) -> std::vector { + std::vector ret; + for (auto& pa : conj.children) { + auto p = to_predicates(pa, 
table_schema_opt); + ret.insert(ret.end(), p.begin(), p.end()); + } + return ret; + }, + [&] (const binary_operator& oper) -> std::vector { return expr::visit(overloaded_functor{ - [&] (const column_value& col) -> solve_for_t { - if (!cdef || cdef != col.col) { - return [] (const query_options&) { return unbounded_value_set; }; - } + [&] (const column_value& col) -> std::vector { + auto cdef = col.col; + auto type = &cdef->type->without_reversed(); if (is_compare(oper.op)) { - return [oper] (const query_options& options) { + auto solve = [oper] (const query_options& options) { managed_bytes_opt val = evaluate(oper.rhs, options).to_managed_bytes_opt(); if (!val) { return empty_value_set; // All NULL comparisons fail; no column values match. @@ -345,30 +365,44 @@ possible_lhs_values(const column_definition* cdef, return oper.op == oper_t::EQ ? value_set(value_list{*val}) : to_range(oper.op, std::move(*val)); }; + return to_vector(predicate{ + .solve_for = std::move(solve), + .filter = oper, + .on = on_column{col.col}, + .is_singleton = (oper.op == oper_t::EQ), + }); } else if (oper.op == oper_t::IN) { - return [oper, type, cdef] (const query_options& options) { + auto solve = [oper, type, cdef] (const query_options& options) { return get_IN_values(oper.rhs, options, type->as_less_comparator(), cdef->name_as_text()); }; + return to_vector(predicate{ + .solve_for = std::move(solve), + .filter = oper, + .on = on_column{col.col}, + .is_singleton = false, + }); } else if (oper.op == oper_t::CONTAINS || oper.op == oper_t::CONTAINS_KEY) { - return [oper] (const query_options& options) { + auto solve = [oper] (const query_options& options) { managed_bytes_opt val = evaluate(oper.rhs, options).to_managed_bytes_opt(); if (!val) { return empty_value_set; // All NULL comparisons fail; no column values match. 
} return value_set(value_list{*val}); }; + return to_vector(predicate{ + .solve_for = std::move(solve), + .filter = oper, + .on = on_column{col.col}, + .is_singleton = false, + }); } - return nullptr; + return cannot_solve_on_column(oper, col.col); }, - [&] (const subscript& s) -> solve_for_t { + [&] (const subscript& s) -> std::vector { const column_value& col = get_subscripted_column(s); - if (!cdef || cdef != col.col) { - return [] (const query_options&) { return unbounded_value_set; }; - } - if (oper.op == oper_t::EQ) { - return [s, oper] (const query_options& options) { + auto solve = [s, oper] (const query_options& options) { managed_bytes_opt sval = evaluate(s.sub, options).to_managed_bytes_opt(); if (!sval) { return empty_value_set; // NULL can't be a map key @@ -382,29 +416,54 @@ possible_lhs_values(const column_definition* cdef, managed_bytes val = tuple_type_impl::build_value_fragmented(elements); return value_set(value_list{val}); }; + return to_vector(predicate{ + .solve_for = std::move(solve), + .filter = oper, + .on = on_column{col.col}, + .is_singleton = true, + }); } - return nullptr; + return cannot_solve_on_column(oper, col.col); }, - [&] (const tuple_constructor& tuple) -> solve_for_t { - return [cdef] (const query_options& options) -> value_set { - on_internal_error(rlogger, - fmt::format("possible_lhs_values: trying to solve for {} on tuple inequality", - cdef ? 
"single column" : "token")); - }; - }, - [&] (const function_call& token_fun_call) -> solve_for_t { - if (!is_partition_token_for_schema(token_fun_call, *table_schema_opt)) { - return nullptr; + [&] (const tuple_constructor& tuple) -> std::vector { + auto columns = tuple.elements + | std::views::transform([] (const expression& e) { return as(e).col; }) + | std::ranges::to(); + for (unsigned i = 0; i < columns.size(); ++i) { + if (!columns[i]->is_clustering_key() || columns[i]->position() != i) { + on_internal_error(rlogger, "to_predicates: multi-column relation not on a clustering key prefix"); + } } - - if (cdef) { - return [] (const query_options&) -> value_set { return unbounded_value_set; }; + // The solve_for lambda is only correct for EQ; other operators + // (IN, slices) are handled directly by + // build_get_multi_column_clustering_bounds_fn() which bypasses + // solve_for and evaluates the binary_operator's RHS itself. + solve_for_t solve = nullptr; + if (oper.op == oper_t::EQ) { + solve = [oper] (const query_options& options) { + managed_bytes_opt val = evaluate(oper.rhs, options).to_managed_bytes_opt(); + if (!val) { + return empty_value_set; // All NULL comparisons fail; no column values match. 
+ } + return value_set(value_list{*val}); + }; + } + return to_vector(predicate{ + .solve_for = std::move(solve), + .filter = oper, + .on = on_clustering_key_prefix{std::move(columns)}, + .is_singleton = oper.op == oper_t::EQ, + }); + }, + [&] (const function_call& token_fun_call) -> std::vector { + if (!is_partition_token_for_schema(token_fun_call, *table_schema_opt)) { + return cannot_solve(oper); } if (!(oper.op == oper_t::EQ || is_slice(oper.op))) { - return nullptr; + return cannot_solve(oper); } - return [oper] (const query_options& options) -> value_set { + auto solve = [oper] (const query_options& options) -> value_set { auto val = evaluate(oper.rhs, options).to_managed_bytes_opt(); if (!val) { return empty_value_set; // All NULL comparisons fail; no token values match. @@ -428,87 +487,157 @@ possible_lhs_values(const column_definition* cdef, } throw std::logic_error(format("get_token_interval unexpected operator {}", oper.op)); }; + return to_vector(predicate{ + .solve_for = std::move(solve), + .filter = oper, + .on = on_partition_key_token{table_schema_opt}, + .is_singleton = (oper.op == oper_t::EQ), + }); }, - [&] (const binary_operator&) -> solve_for_t { - return nullptr; + [&] (const binary_operator&) -> std::vector { + return cannot_solve(oper); }, - [&] (const conjunction&) -> solve_for_t { - return nullptr; + [&] (const conjunction&) -> std::vector { + return cannot_solve(oper); }, - [] (const constant&) -> solve_for_t { - return nullptr; + [&] (const constant&) -> std::vector { + return cannot_solve(oper); }, - [] (const unresolved_identifier&) -> solve_for_t { - return nullptr; + [&] (const unresolved_identifier&) -> std::vector { + return cannot_solve(oper); }, - [] (const column_mutation_attribute&) -> solve_for_t { - return nullptr; + [&] (const column_mutation_attribute&) -> std::vector { + return cannot_solve(oper); }, - [] (const cast&) -> solve_for_t { - return nullptr; + [&] (const cast&) -> std::vector { + return cannot_solve(oper); }, - 
[] (const field_selection&) -> solve_for_t { - return nullptr; + [&] (const field_selection&) -> std::vector { + return cannot_solve(oper); }, - [] (const bind_variable&) -> solve_for_t { - return nullptr; + [&] (const bind_variable&) -> std::vector { + return cannot_solve(oper); }, - [] (const untyped_constant&) -> solve_for_t { - return nullptr; + [&] (const untyped_constant&) -> std::vector { + return cannot_solve(oper); }, - [] (const collection_constructor&) -> solve_for_t { - return nullptr; + [&] (const collection_constructor&) -> std::vector { + return cannot_solve(oper); }, - [] (const usertype_constructor&) -> solve_for_t { - return nullptr; + [&] (const usertype_constructor&) -> std::vector { + return cannot_solve(oper); }, - [] (const temporary&) -> solve_for_t { - return nullptr; + [&] (const temporary&) -> std::vector { + return cannot_solve(oper); }, }, oper.lhs); }, - [] (const column_value&) -> solve_for_t { - return nullptr; + [] (const column_value& cv) -> std::vector { + return cannot_solve(cv); }, - [] (const subscript&) -> solve_for_t { - return nullptr; + [] (const subscript& s) -> std::vector { + return cannot_solve(s); }, - [] (const unresolved_identifier&) -> solve_for_t { - return nullptr; + [] (const unresolved_identifier& ui) -> std::vector { + return cannot_solve(ui); }, - [] (const column_mutation_attribute&) -> solve_for_t { - return nullptr; + [] (const column_mutation_attribute& cma) -> std::vector { + return cannot_solve(cma); }, - [] (const function_call&) -> solve_for_t { - return nullptr; + [] (const function_call& fc) -> std::vector { + return cannot_solve(fc); }, - [] (const cast&) -> solve_for_t { - return nullptr; + [] (const cast& c) -> std::vector { + return cannot_solve(c); }, - [] (const field_selection&) -> solve_for_t { - return nullptr; + [] (const field_selection& fs) -> std::vector { + return cannot_solve(fs); }, - [] (const bind_variable&) -> solve_for_t { - return nullptr; + [] (const bind_variable& bv) -> 
std::vector { + return cannot_solve(bv); }, - [] (const untyped_constant&) -> solve_for_t { - return nullptr; + [] (const untyped_constant& uc) -> std::vector { + return cannot_solve(uc); }, - [] (const tuple_constructor&) -> solve_for_t { - return nullptr; + [] (const tuple_constructor& tc) -> std::vector { + return cannot_solve(tc); }, - [] (const collection_constructor&) -> solve_for_t { - return nullptr; + [] (const collection_constructor& cc) -> std::vector { + return cannot_solve(cc); }, - [] (const usertype_constructor&) -> solve_for_t { - return nullptr; + [] (const usertype_constructor& uc) -> std::vector { + return cannot_solve(uc); }, - [] (const temporary&) -> solve_for_t { - return nullptr; + [] (const temporary& t) -> std::vector { + return cannot_solve(t); }, }, expr); } +// Convert an expression to a predicate on a column. If cdef is nullptr, the predicate +// is on the partition key token. +static +predicate +to_predicate_on_column( + const expression& expr, + const column_definition* cdef, + const schema* table_schema_opt) { + auto predicates = to_predicates(expr, table_schema_opt); + using on_t = std::variant< + on_row, // cannot determine, so predicate is on entire row + on_column, // solving for a single column: e.g. c1 = 3 + on_partition_key_token, // solving for the token, e.g. token(pk1, pk2) >= :var + on_clustering_key_prefix // solving for a clustering key prefix: e.g. (ck1, ck2) >= (3, 4) + >; + auto target = cdef ? 
on_t(on_column{cdef}) : on_t(on_partition_key_token{table_schema_opt}); + auto collected = std::vector{}; + for (auto& predicate : predicates) { + if (predicate.on == target) { + collected.push_back(std::move(predicate)); + continue; + } + } + if (collected.empty()) { + on_internal_error(rlogger, "to_predicate_on_column: no predicates found"); + } + auto ret = std::ranges::fold_left_first( + collected | std::views::as_rvalue, + make_conjunction + ); + if (!ret) { + on_internal_error(rlogger, "to_predicate_on_column: no predicates found"); + } + return std::move(*ret); +} + +// Convert an expression to a single predicate on a clustering key prefix: the matching +// predicates produced by to_predicates() are combined with make_conjunction. +static +predicate +to_predicate_on_clustering_key_prefix( + const expression& expr, + const schema* table_schema_opt) { + auto predicates = to_predicates(expr, table_schema_opt); + std::vector collected; + for (auto& predicate : predicates) { + if (std::holds_alternative(predicate.on)) { + collected.push_back(std::move(predicate)); + continue; + } + } + if (collected.empty()) { + on_internal_error(rlogger, "to_predicate_on_clustering_key_prefix: no predicates found"); + } + auto ret = std::ranges::fold_left_first( + collected | std::views::as_rvalue, + make_conjunction + ); + if (!ret) { + on_internal_error(rlogger, "to_predicate_on_clustering_key_prefix: no predicates found"); + } + return std::move(*ret); +} + interval to_range(const value_set& s) { return std::visit(overloaded_functor{ [] (const interval& r) { return r; }, @@ -844,12 +973,7 @@ bool is_empty_restriction(const expression& e) { static std::function build_value_for_fn(const column_definition& cdef, const expression& e, const schema& s) { - auto ac = predicate{ - .solve_for = possible_lhs_values(&cdef, e, &s), - .filter = e, - .on = on_column{&cdef}, - .is_singleton = false, // Code below assumes 0 or 1 results. 
- }; + auto ac = to_predicate_on_column(e, &cdef, &s); return [ac] (const query_options& options) -> bytes_opt { value_set possible_vals = solve(ac, options); return std::visit(overloaded_functor { @@ -948,12 +1072,7 @@ static partition_range_restrictions extract_partition_range( auto s = &cv; with_current_binary_operator(*this, [&] (const binary_operator& b) { if (s->col->is_partition_key() && (b.op == oper_t::EQ || b.op == oper_t::IN)) { - auto a = predicate{ - .solve_for = possible_lhs_values(s->col, b, table_schema.get()), - .filter = b, - .on = on_column{s->col}, - .is_singleton = b.op == oper_t::EQ, - }; + auto a = to_predicate_on_column(b, s->col, table_schema.get()); const auto [it, inserted] = single_column.try_emplace(s->col, std::move(a)); if (!inserted) { it->second = make_conjunction(std::move(it->second), std::move(a)); @@ -1022,13 +1141,7 @@ static partition_range_restrictions extract_partition_range( if (v.tokens) { return token_range_restrictions{ - .token_restrictions = predicate{ - // It's not really a column, but... 
- .solve_for = possible_lhs_values(/* col */ nullptr, *v.tokens, schema.get()), - .filter = *v.tokens, - .on = on_partition_key_token{schema.get()}, - .is_singleton = false, // It could return a single token, but it's not important to track it - }, + .token_restrictions = to_predicate_on_column(*v.tokens, nullptr, schema.get()), }; } if (v.single_column.size() == schema->partition_key_size()) { @@ -1079,13 +1192,7 @@ static std::vector extract_clustering_prefix_restrictions( } } with_current_binary_operator(*this, [&] (const binary_operator& b) { - multi.push_back(predicate{ - .solve_for = possible_lhs_values(/* col */ nullptr, b, table_schema.get()), - .filter = b, - .on = on_clustering_key_prefix{prefix}, - .is_singleton = false, - .comparable = false, - }); + multi.push_back(to_predicate_on_clustering_key_prefix(b, table_schema.get())); }); } @@ -1093,12 +1200,7 @@ static std::vector extract_clustering_prefix_restrictions( auto s = &cv; with_current_binary_operator(*this, [&] (const binary_operator& b) { if (s->col->is_clustering_key()) { - auto a = predicate{ - .solve_for = possible_lhs_values(s->col, b, table_schema.get()), - .filter = b, - .on = on_column{s->col}, - .is_singleton = b.op == oper_t::EQ, - }; + auto a = to_predicate_on_column(b, s->col, table_schema.get()); const auto [it, inserted] = single.try_emplace(s->col, std::move(a)); if (!inserted) { it->second = make_conjunction(std::move(it->second), std::move(a)); @@ -1112,12 +1214,7 @@ static std::vector extract_clustering_prefix_restrictions( with_current_binary_operator(*this, [&] (const binary_operator& b) { if (cval.col->is_clustering_key()) { - auto a = predicate{ - .solve_for = possible_lhs_values(cval.col, b, table_schema.get()), - .filter = b, - .on = on_column{cval.col}, - .is_singleton = b.op == oper_t::EQ, - }; + auto a = to_predicate_on_column(b, cval.col, table_schema.get()); const auto [it, inserted] = single.try_emplace(cval.col, std::move(a)); if (!inserted) { it->second = 
make_conjunction(std::move(it->second), std::move(a)); @@ -2497,10 +2594,10 @@ static std::vector get_index_v1_token_range_clustering_ const column_definition& token_column, const predicate& token_restriction) { - // A workaround in order to make possible_lhs_values work properly. - // possible_lhs_values looks at the column type and uses this type's comparator. + // A workaround in order to make to_predicate work properly. + // to_predicate looks at the column type and uses this type's comparator. // This is a problem because when using blob's comparator, -4 is greater than 4. - // This makes possible_lhs_values think that an expression like token(p) > -4 and token(p) < 4 + // This makes to_predicate think that an expression like token(p) > -4 and token(p) < 4 // is impossible to fulfill. // Create a fake token column with the type set to bigint, translate the restriction to use this column // and use this restriction to calculate possible lhs values. @@ -2876,12 +2973,7 @@ void statement_restrictions::prepare_indexed_global(const schema& idx_tbl_schema // This means that p1 and p2 can have many different values (token is a hash, can have collisions). // Clustering prefix ends after token_restriction, all further restrictions have to be filtered. expr::expression token_restriction = replace_partition_token(_partition_key_restrictions, token_column, *_schema); - _idx_tbl_ck_prefix = std::vector{predicate{ - .solve_for = possible_lhs_values(token_column, token_restriction, _schema.get()), - .filter = token_restriction, - .on = on_column{token_column}, - .is_singleton = false, // FIXME: could be a singleton token. Not very important. 
- }}; + _idx_tbl_ck_prefix = std::vector{to_predicate_on_column(token_restriction, token_column, _schema.get())}; return; } @@ -2980,12 +3072,7 @@ void statement_restrictions::prepare_indexed_local(const schema& idx_tbl_schema) // Translate the restriction to use column from the index schema and add it expr::expression replaced_idx_restriction = replace_column_def(idx_col_restriction_expr, &indexed_column); - _idx_tbl_ck_prefix->push_back(predicate{ - .solve_for = possible_lhs_values(&indexed_column, replaced_idx_restriction, _schema.get()), - .filter = replaced_idx_restriction, - .on = on_column{&indexed_column}, - .is_singleton = false, // Could be true, but not important. - }); + _idx_tbl_ck_prefix->push_back(to_predicate_on_column(replaced_idx_restriction, &indexed_column, _schema.get())); // Add restrictions for the clustering key add_clustering_restrictions_to_idx_ck_prefix(idx_tbl_schema); @@ -3004,12 +3091,7 @@ void statement_restrictions::add_clustering_restrictions_to_idx_ck_prefix(const const auto col = expr::as(any_binop->lhs).col; auto col_in_index = idx_tbl_schema.get_column_definition(col->name()); auto replaced = replace_column_def(e.filter, col_in_index); - auto a = predicate{ - .solve_for = possible_lhs_values(col_in_index, replaced, &idx_tbl_schema), - .filter = replaced, - .on = on_column{col_in_index}, - .is_singleton = false, // FIXME: could be a singleton token. Not very important. - }; + auto a = to_predicate_on_column(replaced, col_in_index, &idx_tbl_schema); _idx_tbl_ck_prefix->push_back(std::move(a)); } } diff --git a/cql3/restrictions/statement_restrictions.hh b/cql3/restrictions/statement_restrictions.hh index 37b96f31ac..74f08ffc1f 100644 --- a/cql3/restrictions/statement_restrictions.hh +++ b/cql3/restrictions/statement_restrictions.hh @@ -35,6 +35,10 @@ using value_set = std::variant>; // clause to TRUE. 
using solve_for_t = std::function; +struct on_row { + bool operator==(const on_row&) const = default; +}; + struct on_column { const column_definition* column; @@ -65,6 +69,7 @@ struct predicate { expr::expression filter; // What column the predicate can be solved for std::variant< + on_row, // cannot determine, so predicate is on entire row on_column, // solving for a single column: e.g. c1 = 3 on_partition_key_token, // solving for the token, e.g. token(pk1, pk2) >= :var on_clustering_key_prefix // solving for a clustering key prefix: e.g. (ck1, ck2) >= (3, 4)