diff --git a/cql3/restrictions/statement_restrictions.cc b/cql3/restrictions/statement_restrictions.cc index e34e6ca5b0..c48c02afe1 100644 --- a/cql3/restrictions/statement_restrictions.cc +++ b/cql3/restrictions/statement_restrictions.cc @@ -1082,126 +1082,6 @@ void with_current_binary_operator( func(*visitor.current_binary_operator); } -/// Every token, or if no tokens, an EQ/IN of every single PK column. -static partition_range_restrictions extract_partition_range( - std::span where_clause, schema_ptr schema) { - using namespace expr; - struct extract_partition_range_visitor { - schema_ptr table_schema; - std::optional tokens; - std::unordered_map single_column; - const binary_operator* current_binary_operator = nullptr; - - void operator()(const conjunction& c) { - std::ranges::for_each(c.children, [this] (const expression& child) { expr::visit(*this, child); }); - } - - void operator()(const binary_operator& b) { - if (current_binary_operator) { - throw std::logic_error("Nested binary operators are not supported"); - } - current_binary_operator = &b; - expr::visit(*this, b.lhs); - current_binary_operator = nullptr; - } - - void operator()(const function_call& token_fun_call) { - if (!is_partition_token_for_schema(token_fun_call, *table_schema)) { - on_internal_error(rlogger, "extract_partition_range(function_call)"); - } - - with_current_binary_operator(*this, [&] (const binary_operator& b) { - if (tokens) { - tokens = make_conjunction(std::move(*tokens), b); - } else { - tokens = b; - } - }); - } - - void operator()(const column_value& cv) { - auto s = &cv; - with_current_binary_operator(*this, [&] (const binary_operator& b) { - if (s->col->is_partition_key() && (b.op == oper_t::EQ || b.op == oper_t::IN)) { - auto a = to_predicate_on_column(b, s->col, table_schema.get()); - const auto [it, inserted] = single_column.try_emplace(s->col, std::move(a)); - if (!inserted) { - it->second = make_conjunction(std::move(it->second), std::move(a)); - } - } - }); - } - - void operator()(const tuple_constructor& s) { - // Partition key columns are not legal in tuples, so ignore tuples. - } - - void operator()(const subscript& sub) { - const column_value& cval = get_subscripted_column(sub.val); - if (cval.col->is_partition_key()) { - on_internal_error(rlogger, "extract_partition_range(subscript)"); - } - } - - void operator()(const constant&) {} - - void operator()(const unresolved_identifier&) { - on_internal_error(rlogger, "extract_partition_range(unresolved_identifier)"); - } - - void operator()(const column_mutation_attribute&) { - on_internal_error(rlogger, "extract_partition_range(column_mutation_attribute)"); - } - - void operator()(const cast&) { - on_internal_error(rlogger, "extract_partition_range(cast)"); - } - - void operator()(const field_selection&) { - on_internal_error(rlogger, "extract_partition_range(field_selection)"); - } - - void operator()(const bind_variable&) { - on_internal_error(rlogger, "extract_partition_range(bind_variable)"); - } - - void operator()(const untyped_constant&) { - on_internal_error(rlogger, "extract_partition_range(untyped_constant)"); - } - - void operator()(const collection_constructor&) { - on_internal_error(rlogger, "extract_partition_range(collection_constructor)"); - } - - void operator()(const usertype_constructor&) { - on_internal_error(rlogger, "extract_partition_range(usertype_constructor)"); - } - - void operator()(const temporary&) { - on_internal_error(rlogger, "extract_partition_range(temporary)"); - } - }; - - extract_partition_range_visitor v { - .table_schema = schema - }; - - for (auto& e : where_clause) { - expr::visit(v, e); - } - - if (v.tokens) { - return token_range_restrictions{ - .token_restrictions = to_predicate_on_column(*v.tokens, nullptr, schema.get()), - }; - } - if (v.single_column.size() == schema->partition_key_size()) { - return single_column_partition_range_restrictions{ - .per_column_restrictions = v.single_column | std::views::values | std::ranges::to(), - }; - } - return no_partition_range_restrictions{}; -} /// Extracts where_clause atoms with clustering-column LHS and copies them to a vector. These elements define the /// boundaries of any clustering slice that can possibly meet where_clause. This vector can be calculated before @@ -1390,6 +1270,8 @@ statement_restrictions::statement_restrictions(private_tag, const predicate* first_mc_pred = nullptr; bool pk_is_empty = true; bool has_token = false; + std::optional token_pred; + std::unordered_map pk_range_preds; for (auto& pred : predicates) { if (pred.is_not_null_single_column) { auto* col = require_on_single_column(pred); @@ -1468,6 +1350,11 @@ statement_restrictions::statement_restrictions(private_tag, _partition_key_restrictions = expr::make_conjunction(_partition_key_restrictions, pred.filter); pk_is_empty = false; has_token = true; + if (token_pred) { + token_pred = make_conjunction(std::move(*token_pred), pred); + } else { + token_pred = pred; + } } else if (std::holds_alternative(pred.on)) { const column_definition* def = std::get(pred.on).column; if (def->is_partition_key()) { @@ -1493,6 +1380,12 @@ statement_restrictions::statement_restrictions(private_tag, auto [it, inserted] = _single_column_partition_key_restrictions.try_emplace(def, expr::conjunction{}); it->second = expr::make_conjunction(std::move(it->second), pred.filter); } + if (pred.equality || pred.is_in) { + auto [it, inserted] = pk_range_preds.try_emplace(def, pred); + if (!inserted) { + it->second = make_conjunction(std::move(it->second), pred); + } + } _partition_range_is_simple &= !pred.is_in; } else if (def->is_clustering_key()) { if (has_mc_clustering) { @@ -1546,7 +1439,15 @@ statement_restrictions::statement_restrictions(private_tag, } if (!_where.empty()) { _clustering_prefix_restrictions = extract_clustering_prefix_restrictions(_where, _schema); - _partition_range_restrictions = extract_partition_range(_where, _schema); + if (token_pred) { + _partition_range_restrictions = token_range_restrictions{ + .token_restrictions = std::move(*token_pred), + }; + } else if (pk_range_preds.size() == _schema->partition_key_size()) { + _partition_range_restrictions = single_column_partition_range_restrictions{ + .per_column_restrictions = std::move(pk_range_preds) | std::views::values | std::ranges::to(), + }; + } } _has_multi_column = has_mc_clustering; if (_check_indexes) {