diff --git a/cql3/lists.hh b/cql3/lists.hh index 38750f85f9..cb20b112ea 100644 --- a/cql3/lists.hh +++ b/cql3/lists.hh @@ -106,6 +106,9 @@ public: virtual bool contains_bind_marker() const override; virtual void collect_marker_specification(variable_specifications& bound_names) const override; virtual shared_ptr bind(const query_options& options) override; + const std::vector>& get_elements() const { + return _elements; + } }; /** diff --git a/cql3/operator.hh b/cql3/operator.hh index 36ef8132ee..f4b4771e2d 100644 --- a/cql3/operator.hh +++ b/cql3/operator.hh @@ -82,6 +82,9 @@ public: // EQ, LT, LTE, GT, GTE, NEQ return _b < 5 || _b == 8; } + bool needs_filtering() const { + return (*this == CONTAINS) || (*this == CONTAINS_KEY) || (*this == LIKE); + } sstring to_string() const { return _text; } bool operator==(const operator_type& other) const { return this == &other; } bool operator!=(const operator_type& other) const { return this != &other; } diff --git a/cql3/restrictions/restriction.hh b/cql3/restrictions/restriction.hh index a40b73a0a0..6b03a6211d 100644 --- a/cql3/restrictions/restriction.hh +++ b/cql3/restrictions/restriction.hh @@ -41,19 +41,30 @@ #pragma once +#include +#include +#include #include #include +#include #include #include +#include "utils/overloaded_functor.hh" #include "cql3/query_options.hh" #include "cql3/term.hh" #include "cql3/statements/bound.hh" #include "index/secondary_index_manager.hh" +#include "query-result-reader.hh" +#include "range.hh" #include "types.hh" namespace cql3 { +namespace selection { +class selection; +} // namespace selection + namespace restrictions { struct allow_local_index_tag {}; @@ -91,6 +102,135 @@ struct conjunction { std::vector children; }; +/// Creates a conjunction of a and b. If either a or b is itself a conjunction, its children are inserted +/// directly into the resulting conjunction's children, flattening the expression tree. +extern expression make_conjunction(expression a, expression b); + +/// True iff restr is satisfied with respect to the row provided from a partition slice. +extern bool is_satisfied_by( + const expression& restr, + const std::vector& partition_key, const std::vector& clustering_key, + const query::result_row_view& static_row, const query::result_row_view* row, + const selection::selection&, const query_options&); + +/// True iff restr is satisfied with respect to the row provided from a mutation. +extern bool is_satisfied_by( + const expression& restr, + const schema& schema, const partition_key& key, const clustering_key_prefix& ckey, const row& cells, + const query_options& options, gc_clock::time_point now); + +/// Finds the first binary_operator in restr that represents a bound and returns its RHS as a tuple. If no +/// such binary_operator exists, returns an empty vector. The search is depth first. +extern std::vector first_multicolumn_bound(const expression&, const query_options&, statements::bound); + +/// A set of discrete values. +using value_list = std::vector; // Sorted and deduped using value comparator. + +/// General set of values. Empty set and single-element sets are always value_list. nonwrapping_range is +/// never singular and never has start > end. Universal set is a nonwrapping_range with both bounds null. +using value_set = std::variant>; + +/// A set of all column values that would satisfy an expression. If column is null, a set of all token values +/// that satisfy. +/// +/// An expression restricts possible values of a column or token: +/// - `A>5` restricts A from below +/// - `A>5 AND A>6 AND B<10 AND A=12 AND B>0` restricts A to 12 and B to between 0 and 10 +/// - `A IN (1, 3, 5)` restricts A to 1, 3, or 5 +/// - `A IN (1, 3, 5) AND A>3` restricts A to just 5 +/// - `A=1 AND A<=0` restricts A to an empty list; no value is able to satisfy the expression +/// - `A>=NULL` also restricts A to an empty list; all comparisons to NULL are false +/// - an expression without A "restricts" A to unbounded range +extern value_set possible_lhs_values(const column_definition*, const expression&, const query_options&); + +/// Turns value_set into a range, unless it's a multi-valued list (in which case this throws). +extern nonwrapping_range to_range(const value_set&); + +/// True iff expr references the function. +extern bool uses_function(const expression& expr, const sstring& ks_name, const sstring& function_name); + +/// True iff the index can support the entire expression. +extern bool is_supported_by(const expression&, const secondary_index::index&); + +/// True iff any of the indices from the manager can support the entire expression. If allow_local, use all +/// indices; otherwise, use only global indices. +extern bool has_supporting_index( + const expression&, const secondary_index::secondary_index_manager&, allow_local_index allow_local); + +extern sstring to_string(const expression&); + +extern std::ostream& operator<<(std::ostream&, const column_value&); + +extern std::ostream& operator<<(std::ostream&, const expression&); + +/// If there is a binary_operator atom b for which f(b) is true, returns it. Otherwise returns null. +template +const binary_operator* find_if(const expression& e, Fn f) { + return std::visit(overloaded_functor{ + [&] (const binary_operator& op) { return f(op) ? &op : nullptr; }, + [] (bool) -> const binary_operator* { return nullptr; }, + [&] (const conjunction& conj) -> const binary_operator* { + for (auto& child : conj.children) { + if (auto found = find_if(child, f)) { + return found; + } + } + return nullptr; + }, + }, e); +} + +/// Counts binary_operator atoms b for which f(b) is true. +template +size_t count_if(const expression& e, Fn f) { + return std::visit(overloaded_functor{ + [&] (const binary_operator& op) -> size_t { return f(op) ? 1 : 0; }, + [&] (const conjunction& conj) { + return std::accumulate(conj.children.cbegin(), conj.children.cend(), size_t{0}, + [&] (size_t acc, const expression& c) { return acc + count_if(c, f); }); + }, + [] (bool) -> size_t { return 0; }, + }, e); +} + +inline const binary_operator* find(const expression& e, const operator_type& op) { + return find_if(e, [&] (const binary_operator& o) { return *o.op == op; }); +} + +inline bool needs_filtering(const expression& e) { + return find_if(e, [] (const binary_operator& o) { return o.op->needs_filtering(); }); +} + +inline bool has_slice(const expression& e) { + return find_if(e, [] (const binary_operator& o) { return o.op->is_slice(); }); +} + +inline bool has_token(const expression& e) { + return find_if(e, [] (const binary_operator& o) { return std::holds_alternative(o.lhs); }); +} + +inline bool has_slice_or_needs_filtering(const expression& e) { + return find_if(e, [] (const binary_operator& o) { return o.op->is_slice() || o.op->needs_filtering(); }); +} + +/// True iff binary_operator involves a collection. +extern bool is_on_collection(const binary_operator&); + +/// Replaces every column_definition in an expression with this one. Throws if any LHS is not a single +/// column_value. +extern expression replace_column_def(const expression&, const column_definition*); + +/// Makes a binary_operator on a column_definition. +inline expression make_column_op(const column_definition* cdef, const operator_type& op, ::shared_ptr value) { + return binary_operator{std::vector{column_value(cdef)}, &op, std::move(value)}; +} + +inline const operator_type* pick_operator(statements::bound b, bool inclusive) { + return is_start(b) ? + (inclusive ? &operator_type::GTE : &operator_type::GT) : + (inclusive ? &operator_type::LTE : &operator_type::LT); +} + /** * Base class for Restrictions */ @@ -257,3 +397,33 @@ protected: } } + +/// Required for fmt::join() to work on expression. +template <> +struct fmt::formatter { + constexpr auto parse(format_parse_context& ctx) { + return ctx.end(); + } + + template + auto format(const cql3::restrictions::expression& expr, FormatContext& ctx) { + std::ostringstream os; + os << expr; + return format_to(ctx.out(), "{}", os.str()); + } +}; + +/// Required for fmt::join() to work on column_value. +template <> +struct fmt::formatter { + constexpr auto parse(format_parse_context& ctx) { + return ctx.end(); + } + + template + auto format(const cql3::restrictions::column_value& col, FormatContext& ctx) { + std::ostringstream os; + os << col; + return format_to(ctx.out(), "{}", os.str()); + } +}; diff --git a/cql3/restrictions/statement_restrictions.cc b/cql3/restrictions/statement_restrictions.cc index 6ae59a3d6f..894a8ceaf0 100644 --- a/cql3/restrictions/statement_restrictions.cc +++ b/cql3/restrictions/statement_restrictions.cc @@ -32,10 +32,13 @@ #include "token_restriction.hh" #include "database.hh" -#include "cql3/single_column_relation.hh" #include "cql3/constants.hh" -#include "types/map.hh" +#include "cql3/lists.hh" +#include "cql3/selection/selection.hh" +#include "cql3/single_column_relation.hh" +#include "cql3/tuples.hh" #include "types/list.hh" +#include "types/map.hh" #include "types/set.hh" namespace cql3 { @@ -1003,5 +1006,791 @@ void single_column_restriction::LIKE::merge_with(::shared_ptr rest) return r; } + +namespace { + +using children_t = std::vector; // conjunction's children. + +children_t explode_conjunction(expression e) { + return std::visit(overloaded_functor{ + [] (const conjunction& c) { return std::move(c.children); }, + [&] (const auto&) { return children_t{std::move(e)}; }, + }, e); +} + +using cql3::selection::selection; + +/// Serialized values for all types of cells, plus selection (to find a column's index) and options (for +/// subscript term's value). +struct row_data_from_partition_slice { + const std::vector& partition_key; + const std::vector& clustering_key; + const std::vector& other_columns; + const selection& sel; +}; + +/// Data used to derive cell values from a mutation. +struct row_data_from_mutation { + // Underscores avoid name clashes. + const partition_key& partition_key_; + const clustering_key_prefix& clustering_key_; + const row& other_columns; + const schema& schema_; + gc_clock::time_point now; +}; + +/// Everything needed to compute column values during restriction evaluation. +struct column_value_eval_bag { + const query_options& options; // For evaluating subscript terms. + std::variant row_data; +}; + +/// Returns col's value from queried data. +bytes_opt get_value_from_partition_slice( + const column_value& col, row_data_from_partition_slice data, const query_options& options) { + auto cdef = col.col; + if (col.sub) { + auto col_type = static_pointer_cast(cdef->type); + if (!col_type->is_map()) { + throw exceptions::invalid_request_exception(format("subscripting non-map column {}", cdef->name_as_text())); + } + const auto deserialized = cdef->type->deserialize(*data.other_columns[data.sel.index_of(*cdef)]); + const auto& data_map = value_cast(deserialized); + const auto key = col.sub->bind_and_get(options); + auto&& key_type = col_type->name_comparator(); + const auto found = with_linearized(*key, [&] (bytes_view key_bv) { + using entry = std::pair; + return std::find_if(data_map.cbegin(), data_map.cend(), [&] (const entry& element) { + return key_type->compare(element.first.serialize_nonnull(), key_bv) == 0; + }); + }); + return found == data_map.cend() ? bytes_opt() : bytes_opt(found->second.serialize_nonnull()); + } else { + switch (cdef->kind) { + case column_kind::partition_key: + return data.partition_key[cdef->id]; + case column_kind::clustering_key: + return data.clustering_key[cdef->id]; + case column_kind::static_column: + case column_kind::regular_column: + return data.other_columns[data.sel.index_of(*cdef)]; + default: + throw exceptions::unsupported_operation_exception("Unknown column kind"); + } + } +} + +/// Returns col's value from a mutation. +bytes_opt get_value_from_mutation(const column_value& col, row_data_from_mutation data) { + const auto v = do_get_value( + data.schema_, *col.col, data.partition_key_, data.clustering_key_, data.other_columns, data.now); + return v ? v->linearize() : bytes_opt(); +} + +/// Returns col's value from the fetched data. +bytes_opt get_value(const column_value& col, const column_value_eval_bag& bag) { + using std::placeholders::_1; + return std::visit(overloaded_functor{ + std::bind(get_value_from_mutation, col, _1), + std::bind(get_value_from_partition_slice, col, _1, bag.options), + }, bag.row_data); +} + +/// Type for comparing results of get_value(). +const abstract_type* get_value_comparator(const column_definition* cdef) { + return cdef->type->is_reversed() ? cdef->type->underlying_type().get() : cdef->type.get(); +} + +/// Type for comparing results of get_value(). +const abstract_type* get_value_comparator(const column_value& cv) { + return cv.sub ? static_pointer_cast(cv.col->type)->value_comparator().get() + : get_value_comparator(cv.col); +} + +/// If t represents a tuple value, returns that value. Otherwise, null. +/// +/// Useful for checking binary_operator::rhs, which packs multiple values into a single term when lhs is itself +/// a tuple. NOT useful for the IN operator, whose rhs is either a list or tuples::in_value. +::shared_ptr get_tuple(term& t, const query_options& opts) { + return dynamic_pointer_cast(t.bind(opts)); +} + +/// True iff lhs's value equals rhs. +bool equal(const bytes_opt& rhs, const column_value& lhs, const column_value_eval_bag& bag) { + if (!rhs) { + return false; + } + const auto value = get_value(lhs, bag); + if (!value) { + return false; + } + return get_value_comparator(lhs)->equal(*value, *rhs); +} + +/// True iff columns' values equal t. +bool equal(::shared_ptr t, const std::vector& columns, const column_value_eval_bag& bag) { + if (columns.size() > 1) { + const auto tup = get_tuple(*t, bag.options); + if (!tup) { + throw exceptions::invalid_request_exception("multi-column equality has right-hand side that isn't a tuple"); + } + const auto& rhs = tup->get_elements(); + if (rhs.size() != columns.size()) { + throw exceptions::invalid_request_exception( + format("tuple equality size mismatch: {} elements on left-hand side, {} on right", + columns.size(), rhs.size())); + } + return boost::equal(rhs, columns, [&] (const bytes_opt& rhs, const column_value& lhs) { + return equal(rhs, lhs, bag); + }); + } else if (columns.size() == 1) { + const auto tup = get_tuple(*t, bag.options); + if (tup && tup->size() == 1) { + // Assume this is an external query WHERE (ck1)=(123), rather than an internal query WHERE + // col=(123), because internal queries have no reason to use single-element tuples. + // + // TODO: make the two cases distinguishable. + return equal(tup->get_elements()[0], columns[0], bag); + } + return equal(to_bytes_opt(t->bind_and_get(bag.options)), columns[0], bag); + } else { + throw std::logic_error("empty tuple on LHS of ="); + } +} + +/// True iff lhs is limited by rhs in the manner prescribed by op. +bool limits(bytes_view lhs, const operator_type& op, bytes_view rhs, const abstract_type& type) { + if (!op.is_compare()) { + throw std::logic_error("limits() called on non-compare op"); + } + const auto cmp = type.compare(lhs, rhs); + if (cmp < 0) { + return op == operator_type::LT || op == operator_type::LTE || op == operator_type::NEQ; + } else if (cmp > 0) { + return op == operator_type::GT || op == operator_type::GTE || op == operator_type::NEQ; + } else { + return op == operator_type::LTE || op == operator_type::GTE || op == operator_type::EQ; + } +} + +/// True iff the value of opr.lhs (which must be column_values) is limited by opr.rhs in the manner prescribed +/// by opr.op. +bool limits(const binary_operator& opr, const column_value_eval_bag& bag) { + if (!opr.op->is_slice()) { // For EQ or NEQ, use equal(). + throw std::logic_error("limits() called on non-slice op"); + } + const auto& columns = std::get<0>(opr.lhs); + if (columns.size() > 1) { + const auto tup = get_tuple(*opr.rhs, bag.options); + if (!tup) { + throw exceptions::invalid_request_exception("multi-column comparison has right-hand side that isn't a tuple"); + } + const auto& rhs = tup->get_elements(); + if (rhs.size() != columns.size()) { + throw exceptions::invalid_request_exception( + format("tuple comparison size mismatch: {} elements on left-hand side, {} on right", + columns.size(), rhs.size())); + } + for (size_t i = 0; i < rhs.size(); ++i) { + const auto cmp = get_value_comparator(columns[i])->compare( + // CQL dictates that columns[i] is a clustering column and non-null. + *get_value(columns[i], bag), + *rhs[i]); + // If the components aren't equal, then we just learned the LHS/RHS order. + if (cmp < 0) { + if (*opr.op == operator_type::LT || *opr.op == operator_type::LTE) { + return true; + } else if (*opr.op == operator_type::GT || *opr.op == operator_type::GTE) { + return false; + } else { + throw std::logic_error("Unknown slice operator"); + } + } else if (cmp > 0) { + if (*opr.op == operator_type::LT || *opr.op == operator_type::LTE) { + return false; + } else if (*opr.op == operator_type::GT || *opr.op == operator_type::GTE) { + return true; + } else { + throw std::logic_error("Unknown slice operator"); + } + } + // Otherwise, we don't know the LHS/RHS order, so check the next component. + } + // Getting here means LHS == RHS. + return *opr.op == operator_type::LTE || *opr.op == operator_type::GTE; + } else if (columns.size() == 1) { + auto lhs = get_value(columns[0], bag); + if (!lhs) { + lhs = bytes(); // Compatible with old code, which feeds null to type comparators. + } + const auto tup = get_tuple(*opr.rhs, bag.options); + auto rhs = (tup && tup->size() == 1) ? tup->get_elements()[0] // Assume an external query WHERE (ck1)>(123). + : to_bytes_opt(opr.rhs->bind_and_get(bag.options)); + if (!rhs) { + return false; + } + return limits(*lhs, *opr.op, *rhs, *get_value_comparator(columns[0])); + } else { + throw std::logic_error("empty tuple on LHS of an inequality"); + } +} + +/// True iff collection (list, set, or map) contains value. +bool contains(const data_value& collection, const raw_value_view& value) { + if (!value) { + return true; // Compatible with old code, which skips null terms in value comparisons. + } + auto col_type = static_pointer_cast(collection.type()); + auto&& element_type = col_type->is_set() ? col_type->name_comparator() : col_type->value_comparator(); + return with_linearized(*value, [&] (bytes_view val) { + auto exists_in = [&](auto&& range) { + auto found = std::find_if(range.begin(), range.end(), [&] (auto&& element) { + return element_type->compare(element.serialize_nonnull(), val) == 0; + }); + return found != range.end(); + }; + if (col_type->is_list()) { + return exists_in(value_cast(collection)); + } else if (col_type->is_set()) { + return exists_in(value_cast(collection)); + } else if (col_type->is_map()) { + auto data_map = value_cast(collection); + using entry = std::pair; + return exists_in(data_map | transformed([] (const entry& e) { return e.second; })); + } else { + throw std::logic_error("unsupported collection type in a CONTAINS expression"); + } + }); +} + +/// True iff columns is a single collection containing value. +bool contains(const raw_value_view& value, const std::vector& columns, const column_value_eval_bag& bag) { + if (columns.size() != 1) { + throw exceptions::unsupported_operation_exception("tuple CONTAINS not allowed"); + } + if (columns[0].sub) { + throw exceptions::unsupported_operation_exception("CONTAINS lhs is subscripted"); + } + const auto collection = get_value(columns[0], bag); + if (collection) { + return contains(columns[0].col->type->deserialize(*collection), value); + } else { + return false; + } +} + +/// True iff \p columns has a single element that's a map containing \p key. +bool contains_key(const std::vector& columns, cql3::raw_value_view key, const column_value_eval_bag& bag) { + if (columns.size() != 1) { + throw exceptions::unsupported_operation_exception("CONTAINS KEY on a tuple"); + } + if (columns[0].sub) { + throw exceptions::unsupported_operation_exception("CONTAINS KEY lhs is subscripted"); + } + if (!key) { + return true; // Compatible with old code, which skips null terms in key comparisons. + } + auto cdef = columns[0].col; + const auto collection = get_value(columns[0], bag); + if (!collection) { + return false; + } + const auto data_map = value_cast(cdef->type->deserialize(*collection)); + auto key_type = static_pointer_cast(cdef->type)->name_comparator(); + auto found = with_linearized(*key, [&] (bytes_view k_bv) { + using entry = std::pair; + return std::find_if(data_map.begin(), data_map.end(), [&] (const entry& element) { + return key_type->compare(element.first.serialize_nonnull(), k_bv) == 0; + }); + }); + return found != data_map.end(); +} + +/// Fetches the next cell value from iter and returns its (possibly null) value. +bytes_opt next_value(query::result_row_view::iterator_type& iter, const column_definition* cdef) { + if (cdef->type->is_multi_cell()) { + auto cell = iter.next_collection_cell(); + if (cell) { + return cell->with_linearized([] (bytes_view data) { + return bytes(data.cbegin(), data.cend()); + }); + } + } else { + auto cell = iter.next_atomic_cell(); + if (cell) { + return cell->value().with_linearized([] (bytes_view data) { + return bytes(data.cbegin(), data.cend()); + }); + } + } + return std::nullopt; +} + +/// Returns values of non-primary-key columns from selection. The kth element of the result +/// corresponds to the kth column in selection. +std::vector get_non_pk_values(const selection& selection, const query::result_row_view& static_row, + const query::result_row_view* row) { + const auto& cols = selection.get_columns(); + std::vector vals(cols.size()); + auto static_row_iterator = static_row.iterator(); + auto row_iterator = row ? std::optional(row->iterator()) : std::nullopt; + for (size_t i = 0; i < cols.size(); ++i) { + switch (cols[i]->kind) { + case column_kind::static_column: + vals[i] = next_value(static_row_iterator, cols[i]); + break; + case column_kind::regular_column: + if (row) { + vals[i] = next_value(*row_iterator, cols[i]); + } + break; + default: // Skip. + break; + } + } + return vals; +} + +/// True iff cv matches the CQL LIKE pattern. +bool like(const column_value& cv, const bytes_opt& pattern, const column_value_eval_bag& bag) { + if (!cv.col->type->is_string()) { + throw exceptions::invalid_request_exception( + format("LIKE is allowed only on string types, which {} is not", cv.col->name_as_text())); + } + auto value = get_value(cv, bag); + // TODO: reuse matchers. + return (pattern && value) ? like_matcher(*pattern)(*value) : false; +} + +/// True iff columns' values match rhs pattern(s) as defined by CQL LIKE. +bool like(const std::vector& columns, term& rhs, const column_value_eval_bag& bag) { + if (columns.size() > 1) { + if (const auto tup = get_tuple(rhs, bag.options)) { + const auto& elements = tup->get_elements(); + if (elements.size() != columns.size()) { + throw exceptions::invalid_request_exception( + format("LIKE tuple size mismatch: {} elements on left-hand side, {} on right", + columns.size(), elements.size())); + } + return boost::equal(columns, elements, [&] (const column_value& cv, const bytes_opt& pattern) { + return like(cv, pattern, bag); + }); + } else { + throw exceptions::invalid_request_exception("multi-column LIKE has right-hand side that isn't a tuple"); + } + } else if (columns.size() == 1) { + return like(columns[0], to_bytes_opt(rhs.bind_and_get(bag.options)), bag); + } else { + throw exceptions::invalid_request_exception("empty tuple on left-hand side of LIKE"); + } +} + +/// True iff the tuple of column values is in the set defined by rhs. +bool is_one_of(const std::vector& cvs, term& rhs, const column_value_eval_bag& bag) { + // RHS is prepared differently for different CQL cases. Cast it dynamically to discern which case this is. + if (auto dv = dynamic_cast(&rhs)) { + // This is either `a IN (1,2,3)` or `(a,b) IN ((1,1),(2,2),(3,3))`. RHS elements are themselves terms. + return boost::algorithm::any_of(dv->get_elements(), [&] (const ::shared_ptr& t) { + return equal(t, cvs, bag); + }); + } else if (auto mkr = dynamic_cast(&rhs)) { + // This is `a IN ?`. RHS elements are values representable as bytes_opt. + if (cvs.size() != 1) { + throw std::logic_error("too many columns for lists::marker in is_one_of"); + } + const auto values = static_pointer_cast(mkr->bind(bag.options)); + return boost::algorithm::any_of(values->get_elements(), [&] (const bytes_opt& b) { + return equal(b, cvs[0], bag); + }); + } else if (auto mkr = dynamic_cast(&rhs)) { + // This is `(a,b) IN ?`. RHS elements are themselves tuples, represented as vector. + const auto marker_value = static_pointer_cast(mkr->bind(bag.options)); + return boost::algorithm::any_of(marker_value->get_split_values(), [&] (const std::vector& el) { + return boost::equal(cvs, el, [&] (const column_value& c, const bytes_opt& b) { + return equal(b, c, bag); + }); + }); + } + throw std::logic_error("unexpected term type in is_one_of"); +} + +/// True iff op means bnd type of bound. +bool matches(const operator_type* op, statements::bound bnd) { + static const std::vector> operators{ + {&operator_type::EQ, &operator_type::GT, &operator_type::GTE}, // These mean a lower bound. + {&operator_type::EQ, &operator_type::LT, &operator_type::LTE}, // These mean an upper bound. + }; + const auto zero_if_lower_one_if_upper = get_idx(bnd); + return boost::algorithm::any_of_equal(operators[zero_if_lower_one_if_upper], op); +} + +const value_set empty_value_set = value_list{}; +const value_set unbounded_value_set = nonwrapping_range::make_open_ended_both_sides(); + +struct intersection_visitor { + const abstract_type* type; + value_set operator()(const value_list& a, const value_list& b) const { + value_list common; + common.reserve(std::max(a.size(), b.size())); + boost::set_intersection(a, b, back_inserter(common), type->as_less_comparator()); + return std::move(common); + } + + value_set operator()(const nonwrapping_range& a, const value_list& b) const { + const auto common = b | filtered([&] (const bytes& el) { return a.contains(el, type->as_tri_comparator()); }); + return value_list(common.begin(), common.end()); + } + + value_set operator()(const value_list& a, const nonwrapping_range& b) const { + return (*this)(b, a); + } + + value_set operator()(const nonwrapping_range& a, const nonwrapping_range& b) const { + const auto common_range = a.intersection(b, type->as_tri_comparator()); + return common_range ? *common_range : empty_value_set; + } +}; + +value_set intersection(value_set a, value_set b, const abstract_type* type) { + return std::visit(intersection_visitor{type}, std::move(a), std::move(b)); +} + +bool is_satisfied_by(const binary_operator& opr, const column_value_eval_bag& bag) { + return std::visit(overloaded_functor{ + [&] (const std::vector& cvs) { + if (*opr.op == operator_type::EQ) { + return equal(opr.rhs, cvs, bag); + } else if (*opr.op == operator_type::NEQ) { + return !equal(opr.rhs, cvs, bag); + } else if (opr.op->is_slice()) { + return limits(opr, bag); + } else if (*opr.op == operator_type::CONTAINS) { + return contains(opr.rhs->bind_and_get(bag.options), cvs, bag); + } else if (*opr.op == operator_type::CONTAINS_KEY) { + return contains_key(cvs, opr.rhs->bind_and_get(bag.options), bag); + } else if (*opr.op == operator_type::LIKE) { + return like(cvs, *opr.rhs, bag); + } else if (*opr.op == operator_type::IN) { + return is_one_of(cvs, *opr.rhs, bag); + } else { + throw exceptions::unsupported_operation_exception("Unhandled binary_operator"); + } + }, + [] (const token& tok) -> bool { + // The RHS value was already used to ensure we fetch only rows in the specified + // token range. It is impossible for any fetched row not to match now. + return true; + }, + }, opr.lhs); +} + +bool is_satisfied_by(const expression& restr, const column_value_eval_bag& bag) { + return std::visit(overloaded_functor{ + [&] (bool v) { return v; }, + [&] (const conjunction& conj) { + return boost::algorithm::all_of(conj.children, [&] (const expression& c) { + return is_satisfied_by(c, bag); + }); + }, + [&] (const binary_operator& opr) { return is_satisfied_by(opr, bag); }, + }, restr); +} + +/// If t is a tuple, binds and gets its k-th element. Otherwise, binds and gets t's whole value. +bytes_opt get_kth(size_t k, const query_options& options, const ::shared_ptr& t) { + auto bound = t->bind(options); + if (auto tup = dynamic_pointer_cast(bound)) { + return tup->get_elements()[k]; + } else { + assert(k == 0 && "non-tuple RHS for multi-column IN"); + return to_bytes_opt(bound->get(options)); + } +} + +template +value_list to_sorted_vector(const Range& r, const serialized_compare& comparator) { + value_list tmp(r.begin(), r.end()); // Need random-access range to sort (r is not necessarily random-access). + const auto unique = boost::unique(boost::sort(tmp, comparator)); + return value_list(unique.begin(), unique.end()); +} + +/// Returns possible values for k-th column from t, which must be RHS of IN. +value_list get_IN_values(const ::shared_ptr& t, size_t k, const query_options& options, + const serialized_compare& comparator) { + const auto non_null = filtered([] (const bytes_opt& b) { return b.has_value(); }); + const auto deref = transformed([] (const bytes_opt& b) { return b.value(); }); + // RHS is prepared differently for different CQL cases. Cast it dynamically to discern which case this is. + if (auto dv = dynamic_pointer_cast(t)) { + // Case `a IN (1,2,3)` or `(a,b) in ((1,1),(2,2),(3,3)). Get kth value from each term element. + const auto result_range = dv->get_elements() + | transformed(std::bind_front(get_kth, k, options)) | non_null | deref; + return to_sorted_vector(result_range, comparator); + } else if (auto mkr = dynamic_pointer_cast(t)) { + // Case `a IN ?`. Collect all list-element values. + assert(k == 0 && "lists::marker is for single-column IN"); + const auto val = static_pointer_cast(mkr->bind(options)); + return to_sorted_vector(val->get_elements() | non_null | deref, comparator); + } else if (auto mkr = dynamic_pointer_cast(t)) { + // Case `(a,b) IN ?`. Get kth value from each vector element. + const auto val = static_pointer_cast(mkr->bind(options)); + const auto result_range = val->get_split_values() + | transformed([k] (const std::vector& v) { return v[k]; }) | non_null | deref; + return to_sorted_vector(result_range, comparator); + } + throw std::logic_error(format("get_IN_values on invalid term {}", *t)); +} + +} // anonymous namespace + +expression make_conjunction(expression a, expression b) { + auto children = explode_conjunction(std::move(a)); + boost::copy(explode_conjunction(std::move(b)), back_inserter(children)); + return conjunction{std::move(children)}; +} + +bool is_satisfied_by( + const expression& restr, + const std::vector& partition_key, const std::vector& clustering_key, + const query::result_row_view& static_row, const query::result_row_view* row, + const selection& selection, const query_options& options) { + const auto regulars = get_non_pk_values(selection, static_row, row); + return is_satisfied_by( + restr, {options, row_data_from_partition_slice{partition_key, clustering_key, regulars, selection}}); +} + +bool is_satisfied_by( + const expression& restr, + const schema& schema, const partition_key& key, const clustering_key_prefix& ckey, const row& cells, + const query_options& options, gc_clock::time_point now) { + return is_satisfied_by(restr, {options, row_data_from_mutation{key, ckey, cells, schema, now}}); +} + +std::vector first_multicolumn_bound( + const expression& restr, const query_options& options, statements::bound bnd) { + auto found = find_if(restr, [bnd] (const binary_operator& oper) { + return matches(oper.op, bnd) && std::holds_alternative>(oper.lhs); + }); + if (found) { + return static_pointer_cast(found->rhs->bind(options))->get_elements(); + } else { + return std::vector{}; + } +} + +value_set possible_lhs_values(const column_definition* cdef, const expression& expr, const query_options& options) { + const auto type = cdef ? get_value_comparator(cdef) : long_type.get(); + return std::visit(overloaded_functor{ + [] (bool b) { + return b ? unbounded_value_set : empty_value_set; + }, + [&] (const conjunction& conj) { + return boost::accumulate(conj.children, unbounded_value_set, + [&] (const value_set& acc, const expression& child) { + return intersection( + std::move(acc), possible_lhs_values(cdef, child, options), type); + }); + }, + [&] (const binary_operator& oper) -> value_set { + static constexpr bool inclusive = true, exclusive = false; + return std::visit(overloaded_functor{ + [&] (const std::vector& cvs) -> value_set { + if (!cdef) { + return unbounded_value_set; + } + const auto found = boost::find_if( + cvs, [&] (const column_value& c) { return c.col == cdef; }); + if (found == cvs.end()) { + return unbounded_value_set; + } + const auto column_index_on_lhs = std::distance(cvs.begin(), found); + if (oper.op->is_compare()) { + const auto tup = get_tuple(*oper.rhs, options); + bytes_opt val = tup ? tup->get_elements()[column_index_on_lhs] + : to_bytes_opt(oper.rhs->bind_and_get(options)); + if (!val) { + return empty_value_set; // All NULL comparisons fail; no column values match. + } + if (*oper.op == operator_type::EQ) { + return value_list{*val}; + } + if (column_index_on_lhs > 0) { + // A multi-column comparison restricts only the first column, because + // comparison is lexicographical. + return unbounded_value_set; + } + if (*oper.op == operator_type::GT) { + return nonwrapping_range::make_starting_with(range_bound(*val, exclusive)); + } else if (*oper.op == operator_type::GTE) { + return nonwrapping_range::make_starting_with(range_bound(*val, inclusive)); + } else if (*oper.op == operator_type::LT) { + return nonwrapping_range::make_ending_with(range_bound(*val, exclusive)); + } else if (*oper.op == operator_type::LTE) { + return nonwrapping_range::make_ending_with(range_bound(*val, inclusive)); + } + throw std::logic_error( + format("get_column_interval unknown comparison operator {}", *oper.op)); + } else if (*oper.op == operator_type::IN) { + return get_IN_values(oper.rhs, column_index_on_lhs, options, type->as_less_comparator()); + } + return unbounded_value_set; + }, + [&] (token) -> value_set { + if (cdef) { + return unbounded_value_set; + } + const auto val = to_bytes_opt(oper.rhs->bind_and_get(options)); + if (!val) { + return empty_value_set; // All NULL comparisons fail; no token values match. + } + if (*oper.op == operator_type::EQ) { + return value_list{*val}; + } else if (*oper.op == operator_type::GT) { + return nonwrapping_range::make_starting_with(range_bound(*val, exclusive)); + } else if (*oper.op == operator_type::GTE) { + return nonwrapping_range::make_starting_with(range_bound(*val, inclusive)); + } + static const bytes MININT = serialized(std::numeric_limits::min()), + MAXINT = serialized(std::numeric_limits::max()); + // Undocumented feature: when the user types `token(...) < MININT`, we interpret + // that as MAXINT for some reason. + const auto adjusted_val = (*val == MININT) ? serialized(MAXINT) : *val; + if (*oper.op == operator_type::LT) { + return nonwrapping_range::make_ending_with(range_bound(adjusted_val, exclusive)); + } else if (*oper.op == operator_type::LTE) { + return nonwrapping_range::make_ending_with(range_bound(adjusted_val, inclusive)); + } + throw std::logic_error(format("get_token_interval invalid operator {}", *oper.op)); + }, + }, oper.lhs); + }, + }, expr); +} + +nonwrapping_range to_range(const value_set& s) { + return std::visit(overloaded_functor{ + [] (const nonwrapping_range& r) { return r; }, + [] (const value_list& lst) { + if (lst.size() != 1) { + throw std::logic_error(format("to_range called on list of size {}", lst.size())); + } + return nonwrapping_range::make_singular(lst[0]); + }, + }, s); +} + +bool uses_function(const expression& expr, const sstring& ks_name, const sstring& function_name) { + return std::visit(overloaded_functor{ + [&] (const conjunction& conj) { + using std::placeholders::_1; + return boost::algorithm::any_of(conj.children, std::bind(uses_function, _1, ks_name, function_name)); + }, + [&] (const binary_operator& oper) { + if (oper.rhs && oper.rhs->uses_function(ks_name, function_name)) { + return true; + } else if (auto columns = std::get_if>(&oper.lhs)) { + return boost::algorithm::any_of(*columns, [&] (const column_value& cv) { + return cv.sub && cv.sub->uses_function(ks_name, function_name); + }); + } + return false; + }, + [&] (const auto& default_case) { return false; }, + }, expr); +} + +bool is_supported_by(const expression& expr, const secondary_index::index& idx) { + using std::placeholders::_1; + return std::visit(overloaded_functor{ + [&] (const conjunction& conj) { + return boost::algorithm::all_of(conj.children, std::bind(is_supported_by, _1, idx)); + }, + [&] (const binary_operator& oper) { + if (auto cvs = std::get_if>(&oper.lhs)) { + return boost::algorithm::any_of(*cvs, [&] (const column_value& c) { + return idx.supports_expression(*c.col, *oper.op); + }); + } + return false; + }, + [] (const auto& default_case) { return false; } + }, expr); +} + +bool has_supporting_index( + const expression& expr, + const secondary_index::secondary_index_manager& index_manager, + allow_local_index allow_local) { + const auto indexes = index_manager.list_indexes(); + const auto support = std::bind(is_supported_by, expr, std::placeholders::_1); + return allow_local ? boost::algorithm::any_of(indexes, support) + : boost::algorithm::any_of( + indexes | filtered([] (const secondary_index::index& i) { return !i.metadata().local(); }), + support); +} + +std::ostream& operator<<(std::ostream& os, const column_value& cv) { + os << *cv.col; + if (cv.sub) { + os << '[' << *cv.sub << ']'; + } + return os; +} + +std::ostream& operator<<(std::ostream& os, const expression& expr) { + std::visit(overloaded_functor{ + [&] (bool b) { os << (b ? "TRUE" : "FALSE"); }, + [&] (const conjunction& conj) { fmt::print(os, "({})", fmt::join(conj.children, ") AND (")); }, + [&] (const binary_operator& opr) { + std::visit(overloaded_functor{ + [&] (const token& t) { os << "TOKEN"; }, + [&] (const std::vector& cvs) { + const bool multi = cvs.size() != 1; + os << (multi ? "(" : ""); + fmt::print(os, "({})", fmt::join(cvs, ",")); + os << (multi ? ")" : ""); + }, + }, opr.lhs); + os << ' ' << *opr.op << ' ' << *opr.rhs; + }, + }, expr); + return os; +} + +sstring to_string(const expression& expr) { + return fmt::format("{}", expr); +} + +bool is_on_collection(const binary_operator& b) { + if (*b.op == operator_type::CONTAINS || *b.op == operator_type::CONTAINS_KEY) { + return true; + } + if (auto cvs = std::get_if>(&b.lhs)) { + return boost::algorithm::any_of(*cvs, [] (const column_value& v) { return v.sub; }); + } + return false; +} + +expression replace_column_def(const expression& expr, const column_definition* new_cdef) { + return std::visit(overloaded_functor{ + [] (bool b){ return expression(b); }, + [&] (const conjunction& conj) { + const auto applied = conj.children | transformed( + std::bind(replace_column_def, std::placeholders::_1, new_cdef)); + return expression(conjunction{std::vector(applied.begin(), applied.end())}); + }, + [&] (const binary_operator& oper) { + return std::visit(overloaded_functor{ + [&] (const std::vector& cvs) { + if (cvs.size() != 1) { + throw std::logic_error(format("replace_column_def invalid LHS: {}", to_string(oper))); + } + return expression(binary_operator{std::vector{column_value{new_cdef}}, oper.op, oper.rhs}); + }, + [&] (const token&) { return expr; }, + }, oper.lhs); + }, + }, expr); +} + } // namespace restrictions } // namespace cql3 diff --git a/cql3/tuples.hh b/cql3/tuples.hh index 3c4d5cc62d..308fd1067a 100644 --- a/cql3/tuples.hh +++ b/cql3/tuples.hh @@ -127,6 +127,9 @@ public: virtual const std::vector& get_elements() const override { return _elements; } + size_t size() const { + return _elements.size(); + } virtual sstring to_string() const override { return format("({})", join(", ", _elements)); }