From d69bf4f0101db3a3c2df53e0f35e8534c629bfc2 Mon Sep 17 00:00:00 2001 From: Avi Kivity Date: Sun, 28 Jul 2024 15:20:16 +0300 Subject: [PATCH] cql3: introduce dialect infrastructure A dialect is a different way to interpret the same CQL statement. Examples: - how duplicate bind variable names are handled (later in this series) - whether `column = NULL` in LWT can return true (as is now) or whether it always returns NULL (as in SQL) Currently, dialect is an empty structure and will be filled in later. It is passed to query_processor methods that also accept a CQL string, and from there to the parser. It is part of the prepared statement cache key, so that if the dialect is changed online, previous parses of the statement are ignored and the statement is prepared again. The patch is careful to pick up the dialect at the entry point (e.g. CQL protocol server) so that the dialect doesn't change while a statement is parsed, prepared, and cached. --- auth/common.cc | 2 +- cql3/Cql.g | 7 ++++ cql3/cql3_type.cc | 3 +- cql3/dialect.hh | 31 +++++++++++++++ cql3/prepared_statements_cache.hh | 11 ++++-- cql3/query_processor.cc | 37 ++++++++++-------- cql3/query_processor.hh | 25 +++++++----- cql3/statements/alter_table_statement.cc | 3 +- cql3/statements/select_statement.cc | 4 +- cql3/util.cc | 23 +++++++----- cql3/util.hh | 11 +++--- db/cql_type_parser.cc | 4 +- db/schema_tables.cc | 4 +- table_helper.cc | 10 ++--- table_helper.hh | 3 +- test/boost/cql_auth_syntax_test.cc | 6 +-- test/boost/expr_test.cc | 2 +- test/boost/statement_restrictions_test.cc | 2 +- test/lib/cql_test_env.cc | 17 ++++++--- tools/schema_loader.cc | 5 ++- transport/server.cc | 46 ++++++++++++++--------- transport/server.hh | 5 ++- 22 files changed, 174 insertions(+), 87 deletions(-) create mode 100644 cql3/dialect.hh diff --git a/auth/common.cc b/auth/common.cc index 1cf3a6a200..7e0b94ef19 100644 --- a/auth/common.cc +++ b/auth/common.cc @@ -72,7 +72,7 @@ static future<> create_legacy_metadata_table_if_missing_impl( SCYLLA_ASSERT(this_shard_id() == 0); // once_among_shards makes sure a function is executed on shard 0 only auto db = qp.db(); - auto parsed_statement = cql3::query_processor::parse_statement(cql); + auto parsed_statement = cql3::query_processor::parse_statement(cql, cql3::dialect{}); auto& parsed_cf_statement = static_cast(*parsed_statement); parsed_cf_statement.prepare_keyspace(meta::legacy::AUTH_KS); diff --git a/cql3/Cql.g b/cql3/Cql.g index b99e97250a..39341b28f3 100644 --- a/cql3/Cql.g +++ b/cql3/Cql.g @@ -68,6 +68,7 @@ options { #include "cql3/statements/ks_prop_defs.hh" #include "cql3/selection/raw_selector.hh" #include "cql3/selection/selectable-expr.hh" +#include "cql3/dialect.hh" #include "cql3/keyspace_element_name.hh" #include "cql3/constants.hh" #include "cql3/operation_impl.hh" @@ -148,6 +149,8 @@ using uexpression = uninitialized; listener_type* listener; + dialect _dialect; + // Keeps the names of all bind variables. For bind variables without a name ('?'), the name is nullptr. // Maps bind_index -> name. std::vector<::shared_ptr> _bind_variable_names; @@ -171,6 +174,10 @@ using uexpression = uninitialized; return s; } + void set_dialect(dialect d) { + _dialect = d; + } + bind_variable new_bind_variables(shared_ptr name) { if (name && _named_bind_variables_indexes.contains(*name)) { diff --git a/cql3/cql3_type.cc b/cql3/cql3_type.cc index ff9bbbb559..115b12655e 100644 --- a/cql3/cql3_type.cc +++ b/cql3/cql3_type.cc @@ -449,7 +449,8 @@ sstring maybe_quote(const sstring& identifier) { // many keywords but allow keywords listed as "unreserved keywords". // So we can use any of them, for example cident. try { - cql3::util::do_with_parser(identifier, std::mem_fn(&cql3_parser::CqlParser::cident)); + // In general it's not a good idea to use the default dialect, but for parsing an identifier, it's okay. + cql3::util::do_with_parser(identifier, dialect{}, std::mem_fn(&cql3_parser::CqlParser::cident)); return identifier; } catch(exceptions::syntax_exception&) { // This alphanumeric string is not a valid identifier, so fall diff --git a/cql3/dialect.hh b/cql3/dialect.hh new file mode 100644 index 0000000000..a714426ba0 --- /dev/null +++ b/cql3/dialect.hh @@ -0,0 +1,31 @@ +// Copyright (C) 2024-present ScyllaDB +// SPDX-License-Identifier: AGPL-3.0-or-later + +#pragma once + +#include + +namespace cql3 { + +struct dialect { + bool operator==(const dialect&) const = default; +}; + +inline +dialect +internal_dialect() { + return dialect{ + }; +} + +} + +template <> +struct fmt::formatter { + constexpr auto parse(format_parse_context& ctx) { return ctx.begin(); } + + template + auto format(const cql3::dialect& d, FormatContext& ctx) const { + return fmt::format_to(ctx.out(), "cql3::dialect{{}}"); + } +}; diff --git a/cql3/prepared_statements_cache.hh b/cql3/prepared_statements_cache.hh index 796dc918f1..7b041265b3 100644 --- a/cql3/prepared_statements_cache.hh +++ b/cql3/prepared_statements_cache.hh @@ -14,6 +14,7 @@ #include "utils/hash.hh" #include "cql3/statements/prepared_statement.hh" #include "cql3/column_specification.hh" +#include "cql3/dialect.hh" namespace cql3 { @@ -37,13 +38,17 @@ class prepared_cache_key_type { public: // derive from cql_prepared_id_type so we can customize the formatter of // cache_key_type - struct cache_key_type : public cql_prepared_id_type {}; + struct cache_key_type : public cql_prepared_id_type { + cache_key_type(cql_prepared_id_type&& id, cql3::dialect d) : cql_prepared_id_type(std::move(id)), dialect(d) {} + cql3::dialect dialect; // Not part of hash, but we don't expect collisions because of that + bool operator==(const cache_key_type& other) const = default; + }; private: cache_key_type _key; public: - explicit prepared_cache_key_type(cql_prepared_id_type cql_id) : _key(std::move(cql_id)) {} + explicit prepared_cache_key_type(cql_prepared_id_type cql_id, dialect d) : _key(std::move(cql_id), d) {} cache_key_type& key() { return _key; } const cache_key_type& key() const { return _key; } @@ -175,7 +180,7 @@ struct hash final { template <> struct fmt::formatter { constexpr auto parse(format_parse_context& ctx) { return ctx.begin(); } auto format(const cql3::prepared_cache_key_type::cache_key_type& p, fmt::format_context& ctx) const { - return fmt::format_to(ctx.out(), "{{cql_id: {}}}", static_cast(p)); + return fmt::format_to(ctx.out(), "{{cql_id: {}, dialect: {}}}", static_cast(p), p.dialect); } }; diff --git a/cql3/query_processor.cc b/cql3/query_processor.cc index 4ab775d45e..50f8e00c74 100644 --- a/cql3/query_processor.cc +++ b/cql3/query_processor.cc @@ -565,10 +565,10 @@ query_processor::execute_maybe_with_guard(service::query_state& query_state, ::s } future<::shared_ptr> -query_processor::execute_direct_without_checking_exception_message(const sstring_view& query_string, service::query_state& query_state, query_options& options) { +query_processor::execute_direct_without_checking_exception_message(const sstring_view& query_string, service::query_state& query_state, dialect d, query_options& options) { log.trace("execute_direct: \"{}\"", query_string); tracing::trace(query_state.get_trace_state(), "Parsing a statement"); - auto p = get_statement(query_string, query_state.get_client_state()); + auto p = get_statement(query_string, query_state.get_client_state(), d); auto statement = p->statement; const auto warnings = std::move(p->warnings); if (statement->get_bound_terms() != options.get_values_count()) { @@ -652,18 +652,21 @@ query_processor::process_authorized_statement(const ::shared_ptr } future<::shared_ptr> -query_processor::prepare(sstring query_string, service::query_state& query_state) { +query_processor::prepare(sstring query_string, service::query_state& query_state, cql3::dialect d) { auto& client_state = query_state.get_client_state(); - return prepare(std::move(query_string), client_state); + return prepare(std::move(query_string), client_state, d); } future<::shared_ptr> -query_processor::prepare(sstring query_string, const service::client_state& client_state) { +query_processor::prepare(sstring query_string, const service::client_state& client_state, cql3::dialect d) { using namespace cql_transport::messages; return prepare_one( std::move(query_string), client_state, - compute_id, + d, + [d] (std::string_view query_string, std::string_view keyspace) { + return compute_id(query_string, keyspace, d); + }, prepared_cache_key_type::cql_id); } @@ -675,13 +678,14 @@ static std::string hash_target(std::string_view query_string, std::string_view k prepared_cache_key_type query_processor::compute_id( std::string_view query_string, - std::string_view keyspace) { - return prepared_cache_key_type(md5_hasher::calculate(hash_target(query_string, keyspace))); + std::string_view keyspace, + dialect d) { + return prepared_cache_key_type(md5_hasher::calculate(hash_target(query_string, keyspace)), d); } std::unique_ptr -query_processor::get_statement(const sstring_view& query, const service::client_state& client_state) { - std::unique_ptr statement = parse_statement(query); +query_processor::get_statement(const sstring_view& query, const service::client_state& client_state, dialect d) { + std::unique_ptr statement = parse_statement(query, d); // Set keyspace for statement that require login auto cf_stmt = dynamic_cast(statement.get()); @@ -695,7 +699,7 @@ query_processor::get_statement(const sstring_view& query, const service::client_ } std::unique_ptr -query_processor::parse_statement(const sstring_view& query) { +query_processor::parse_statement(const sstring_view& query, dialect d) { try { { const char* error_injection_key = "query_processor-parse_statement-test_failure"; @@ -705,7 +709,7 @@ query_processor::parse_statement(const sstring_view& query) { } }); } - auto statement = util::do_with_parser(query, std::mem_fn(&cql3_parser::CqlParser::query)); + auto statement = util::do_with_parser(query, d, std::mem_fn(&cql3_parser::CqlParser::query)); if (!statement) { throw exceptions::syntax_exception("Parsing failed"); } @@ -721,9 +725,9 @@ query_processor::parse_statement(const sstring_view& query) { } std::vector> -query_processor::parse_statements(std::string_view queries) { +query_processor::parse_statements(std::string_view queries, dialect d) { try { - auto statements = util::do_with_parser(queries, std::mem_fn(&cql3_parser::CqlParser::queries)); + auto statements = util::do_with_parser(queries, d, std::mem_fn(&cql3_parser::CqlParser::queries)); if (statements.empty()) { throw exceptions::syntax_exception("Parsing failed"); } @@ -796,7 +800,7 @@ query_options query_processor::make_internal_options( statements::prepared_statement::checked_weak_ptr query_processor::prepare_internal(const sstring& query_string) { auto& p = _internal_statements[query_string]; if (p == nullptr) { - auto np = parse_statement(query_string)->prepare(_db, _cql_stats); + auto np = parse_statement(query_string, internal_dialect())->prepare(_db, _cql_stats); np->statement->raw_cql_statement = query_string; p = std::move(np); // inserts it into map } @@ -902,7 +906,8 @@ query_processor::execute_internal( auto p = prepare_internal(query_string); return execute_with_params(std::move(p), cl, query_state, values); } else { - auto p = parse_statement(query_string)->prepare(_db, _cql_stats); + // For internal queries, we want the default dialect, not the user provided one + auto p = parse_statement(query_string, dialect{})->prepare(_db, _cql_stats); p->statement->raw_cql_statement = query_string; auto checked_weak_ptr = p->checked_weak_from_this(); return execute_with_params(std::move(checked_weak_ptr), cl, query_state, values).finally([p = std::move(p)] {}); diff --git a/cql3/query_processor.hh b/cql3/query_processor.hh index b0e6d4fb6a..2677ad004e 100644 --- a/cql3/query_processor.hh +++ b/cql3/query_processor.hh @@ -21,6 +21,7 @@ #include "cql3/authorized_prepared_statements_cache.hh" #include "cql3/statements/prepared_statement.hh" #include "cql3/cql_statement.hh" +#include "cql3/dialect.hh" #include "exceptions/exceptions.hh" #include "service/migration_listener.hh" #include "timestamp.hh" @@ -138,10 +139,11 @@ public: static prepared_cache_key_type compute_id( std::string_view query_string, - std::string_view keyspace); + std::string_view keyspace, + dialect d); - static std::unique_ptr parse_statement(const std::string_view& query); - static std::vector> parse_statements(std::string_view queries); + static std::unique_ptr parse_statement(const std::string_view& query, dialect d); + static std::vector> parse_statements(std::string_view queries, dialect d); query_processor(service::storage_proxy& proxy, data_dictionary::database db, service::migration_notifier& mn, memory_config mcfg, cql_config& cql_cfg, utils::loading_cache_config auth_prep_cache_cfg, lang::manager& langm); @@ -250,10 +252,12 @@ public: execute_direct( const std::string_view& query_string, service::query_state& query_state, + dialect d, query_options& options) { return execute_direct_without_checking_exception_message( query_string, query_state, + d, options) .then(cql_transport::messages::propagate_exception_as_future<::shared_ptr>); } @@ -264,6 +268,7 @@ public: execute_direct_without_checking_exception_message( const std::string_view& query_string, service::query_state& query_state, + dialect d, query_options& options); future<::shared_ptr> @@ -398,10 +403,10 @@ public: future<::shared_ptr> - prepare(sstring query_string, service::query_state& query_state); + prepare(sstring query_string, service::query_state& query_state, dialect d); future<::shared_ptr> - prepare(sstring query_string, const service::client_state& client_state); + prepare(sstring query_string, const service::client_state& client_state, dialect d); future<> stop(); @@ -444,7 +449,8 @@ public: std::unique_ptr get_statement( const std::string_view& query, - const service::client_state& client_state); + const service::client_state& client_state, + dialect d); friend class migration_subscriber; @@ -528,14 +534,15 @@ private: prepare_one( sstring query_string, const service::client_state& client_state, + dialect d, PreparedKeyGenerator&& id_gen, IdGetter&& id_getter) { return do_with( id_gen(query_string, client_state.get_raw_keyspace()), std::move(query_string), - [this, &client_state, &id_getter](const prepared_cache_key_type& key, const sstring& query_string) { - return _prepared_cache.get(key, [this, &query_string, &client_state] { - auto prepared = get_statement(query_string, client_state); + [this, &client_state, &id_getter, d](const prepared_cache_key_type& key, const sstring& query_string) { + return _prepared_cache.get(key, [this, &query_string, &client_state, d] { + auto prepared = get_statement(query_string, client_state, d); auto bound_terms = prepared->statement->get_bound_terms(); if (bound_terms > std::numeric_limits::max()) { throw exceptions::invalid_request_exception( diff --git a/cql3/statements/alter_table_statement.cc b/cql3/statements/alter_table_statement.cc index e5f05f1c14..d0100d140a 100644 --- a/cql3/statements/alter_table_statement.cc +++ b/cql3/statements/alter_table_statement.cc @@ -385,7 +385,8 @@ std::pair> alter_table_statement::prepare_ auto new_where = util::rename_column_in_where_clause( view->view_info()->where_clause(), column_identifier::raw(view_from->text(), true), - column_identifier::raw(view_to->text(), true)); + column_identifier::raw(view_to->text(), true), + cql3::dialect{}); builder.with_view_info(view->view_info()->base_id(), view->view_info()->base_name(), view->view_info()->include_all_columns(), std::move(new_where)); diff --git a/cql3/statements/select_statement.cc b/cql3/statements/select_statement.cc index 194af80164..8970aaab1c 100644 --- a/cql3/statements/select_statement.cc +++ b/cql3/statements/select_statement.cc @@ -2591,7 +2591,9 @@ std::unique_ptr build_select_statement( if (!where_clause.empty()) { out << " WHERE " << where_clause << " ALLOW FILTERING"; } - return do_with_parser(out.str(), std::mem_fn(&cql3_parser::CqlParser::selectStatement)); + // In general it's not a good idea to use the default dialect, but here the database is talking to + // itself, so we can hope the dialects are mutually compatible here. + return do_with_parser(out.str(), dialect{}, std::mem_fn(&cql3_parser::CqlParser::selectStatement)); } } diff --git a/cql3/util.cc b/cql3/util.cc index 88d6d5a626..7aa65a67a7 100644 --- a/cql3/util.cc +++ b/cql3/util.cc @@ -21,7 +21,7 @@ void __sanitizer_finish_switch_fiber(void* fake_stack_save, const void** stack_b namespace cql3::util { -static void do_with_parser_impl_impl(const sstring_view& cql, noncopyable_function f) { +static void do_with_parser_impl_impl(const sstring_view& cql, dialect d, noncopyable_function f) { cql3_parser::CqlLexer::collector_type lexer_error_collector(cql); cql3_parser::CqlParser::collector_type parser_error_collector(cql); cql3_parser::CqlLexer::InputStreamType input{reinterpret_cast(cql.begin()), ANTLR_ENC_UTF8, static_cast(cql.size()), nullptr}; @@ -30,13 +30,14 @@ static void do_with_parser_impl_impl(const sstring_view& cql, noncopyable_functi cql3_parser::CqlParser::TokenStreamType tstream(ANTLR_SIZE_HINT, lexer.get_tokSource()); cql3_parser::CqlParser parser{&tstream}; parser.set_error_listener(parser_error_collector); + parser.set_dialect(d); f(parser); } #ifndef DEBUG -void do_with_parser_impl(const sstring_view& cql, noncopyable_function f) { - return do_with_parser_impl_impl(cql, std::move(f)); +void do_with_parser_impl(const sstring_view& cql, dialect d, noncopyable_function f) { + return do_with_parser_impl_impl(cql, d, std::move(f)); } #else @@ -48,6 +49,7 @@ void do_with_parser_impl(const sstring_view& cql, noncopyable_function&& func; // Exceptions can't be returned from another stack, so store // any thrown exception here @@ -71,7 +73,7 @@ static void thunk(int p1, int p2) { // Complete stack switch started in do_with_parser_impl() __sanitizer_finish_switch_fiber(nullptr, &san.stack_bottom, &san.stack_size); try { - do_with_parser_impl_impl(args->cql, std::move(args->func)); + do_with_parser_impl_impl(args->cql, args->d, std::move(args->func)); } catch (...) { args->ex = std::current_exception(); } @@ -80,11 +82,12 @@ static void thunk(int p1, int p2) { setcontext(&args->caller_stack); }; -void do_with_parser_impl(const sstring_view& cql, noncopyable_function f) { +void do_with_parser_impl(const sstring_view& cql, dialect d, noncopyable_function f) { static constexpr size_t stack_size = 1 << 20; static thread_local std::unique_ptr stack = std::make_unique(stack_size); thunk_args args{ .cql = cql, + .d = d, .func = std::move(f), }; ucontext_t uc; @@ -93,7 +96,7 @@ void do_with_parser_impl(const sstring_view& cql, noncopyable_function relations = boolean_factors(where_clause_to_relations(where_clause)); +sstring rename_column_in_where_clause(const sstring_view& where_clause, column_identifier::raw from, column_identifier::raw to, dialect d) { + std::vector relations = boolean_factors(where_clause_to_relations(where_clause, d)); std::vector new_relations; new_relations.reserve(relations.size()); diff --git a/cql3/util.hh b/cql3/util.hh index c179e8b520..fa37b19a21 100644 --- a/cql3/util.hh +++ b/cql3/util.hh @@ -21,18 +21,19 @@ #include "cql3/CqlParser.hpp" #include "cql3/error_collector.hh" #include "cql3/statements/raw/select_statement.hh" +#include "cql3/dialect.hh" namespace cql3 { namespace util { -void do_with_parser_impl(const sstring_view& cql, noncopyable_function func); +void do_with_parser_impl(const sstring_view& cql, dialect d, noncopyable_function func); template >> -Result do_with_parser(const sstring_view& cql, Func&& f) { +Result do_with_parser(const sstring_view& cql, dialect d, Func&& f) { std::optional ret; - do_with_parser_impl(cql, [&] (cql3_parser::CqlParser& parser) { + do_with_parser_impl(cql, d, [&] (cql3_parser::CqlParser& parser) { ret.emplace(f(parser)); }); return std::move(*ret); @@ -40,9 +41,9 @@ Result do_with_parser(const sstring_view& cql, Func&& f) { sstring relations_to_where_clause(const expr::expression& e); -expr::expression where_clause_to_relations(const sstring_view& where_clause); +expr::expression where_clause_to_relations(const sstring_view& where_clause, dialect d); -sstring rename_column_in_where_clause(const sstring_view& where_clause, column_identifier::raw from, column_identifier::raw to); +sstring rename_column_in_where_clause(const sstring_view& where_clause, column_identifier::raw from, column_identifier::raw to, dialect d); /// build a CQL "select" statement with the desired parameters. /// If select_all_columns==true, all columns are selected and the value of diff --git a/db/cql_type_parser.cc b/db/cql_type_parser.cc index a740e015ed..e8661dde5b 100644 --- a/db/cql_type_parser.cc +++ b/db/cql_type_parser.cc @@ -20,7 +20,9 @@ #include "utils/sorting.hh" static ::shared_ptr parse_raw(const sstring& str) { - return cql3::util::do_with_parser(str, + // In general it's a bad idea to use the default dialect, but type parsing + // should be dialect-agnostic. + return cql3::util::do_with_parser(str, cql3::dialect{}, [] (cql3_parser::CqlParser& parser) { return parser.comparator_type(true); }); diff --git a/db/schema_tables.cc b/db/schema_tables.cc index 85592e4c29..3a8289a5a4 100644 --- a/db/schema_tables.cc +++ b/db/schema_tables.cc @@ -1943,7 +1943,9 @@ static shared_ptr create_aggregate(replica::dat bytes_opt initcond = std::nullopt; if (initcond_str) { - auto expr = cql3::util::do_with_parser(*initcond_str, std::mem_fn(&cql3_parser::CqlParser::term)); + // In general using the default dialect is wrong, but here the database is communicating with itself, + // not the user, so any dialect should work. + auto expr = cql3::util::do_with_parser(*initcond_str, cql3::dialect{}, std::mem_fn(&cql3_parser::CqlParser::term)); auto dummy_ident = ::make_shared("", true); auto column_spec = make_lw_shared("", "", dummy_ident, state_type); auto raw = cql3::expr::evaluate(prepare_expression(expr, db.as_data_dictionary(), "", nullptr, {column_spec}), cql3::query_options::DEFAULT); diff --git a/table_helper.cc b/table_helper.cc index a6d6bb456a..5fc15d6013 100644 --- a/table_helper.cc +++ b/table_helper.cc @@ -22,7 +22,7 @@ static logging::logger tlogger("table_helper"); static schema_ptr parse_new_cf_statement(cql3::query_processor& qp, const sstring& create_cql) { auto db = qp.db(); - auto parsed = cql3::query_processor::parse_statement(create_cql); + auto parsed = cql3::query_processor::parse_statement(create_cql, cql3::dialect{}); cql3::statements::raw::cf_statement* parsed_cf_stmt = static_cast(parsed.get()); (void)parsed_cf_stmt->keyspace(); // This will SCYLLA_ASSERT if cql statement did not contain keyspace @@ -68,11 +68,11 @@ future<> table_helper::setup_table(cql3::query_processor& qp, service::migration } catch (...) {} } -future table_helper::try_prepare(bool fallback, cql3::query_processor& qp, service::query_state& qs) { +future table_helper::try_prepare(bool fallback, cql3::query_processor& qp, service::query_state& qs, cql3::dialect dialect) { // Note: `_insert_cql_fallback` is known to be engaged if `fallback` is true, see cache_table_info below. auto& stmt = fallback ? _insert_cql_fallback.value() : _insert_cql; try { - shared_ptr msg_ptr = co_await qp.prepare(stmt, qs.get_client_state()); + shared_ptr msg_ptr = co_await qp.prepare(stmt, qs.get_client_state(), dialect); _prepared_stmt = std::move(msg_ptr->get_prepared()); shared_ptr cql_stmt = _prepared_stmt->statement; _insert_stmt = dynamic_pointer_cast(cql_stmt); @@ -104,12 +104,12 @@ future<> table_helper::cache_table_info(cql3::query_processor& qp, service::migr } try { - bool success = co_await try_prepare(false, qp, qs); + bool success = co_await try_prepare(false, qp, qs, cql3::internal_dialect()); if (_is_fallback_stmt && _prepared_stmt) { co_return; } if (!success) { - co_await try_prepare(true, qp, qs); // Can only return true or exception when preparing the fallback statement + co_await try_prepare(true, qp, qs, cql3::internal_dialect()); // Can only return true or exception when preparing the fallback statement } } catch (...) { auto eptr = std::current_exception(); diff --git a/table_helper.hh b/table_helper.hh index c873948509..4e9217c8e9 100644 --- a/table_helper.hh +++ b/table_helper.hh @@ -18,6 +18,7 @@ class migration_manager; namespace cql3 { class query_processor; +class dialect; namespace statements { class modification_statement; }} @@ -43,7 +44,7 @@ private: bool _is_fallback_stmt = false; private: // Returns true is prepare succeeded, false if failed and there's still a chance to recover, exception if prepare failed and it's not possible to recover - future try_prepare(bool fallback, cql3::query_processor& qp, service::query_state& qs); + future try_prepare(bool fallback, cql3::query_processor& qp, service::query_state& qs, cql3::dialect dialect); public: table_helper(std::string_view keyspace, std::string_view name, sstring create_cql, sstring insert_cql, std::optional insert_cql_fallback = std::nullopt) : _keyspace(keyspace) diff --git a/test/boost/cql_auth_syntax_test.cc b/test/boost/cql_auth_syntax_test.cc index 3ae5a9c32c..09b96591a4 100644 --- a/test/boost/cql_auth_syntax_test.cc +++ b/test/boost/cql_auth_syntax_test.cc @@ -132,7 +132,7 @@ using modifier_rule_ptr = void (cql3_parser::CqlParser::*)(T&); template static T test_valid(std::string_view cql_fragment, producer_rule_ptr rule) { T v; - BOOST_REQUIRE_NO_THROW(v = cql3::util::do_with_parser(cql_fragment, std::mem_fn(rule))); + BOOST_REQUIRE_NO_THROW(v = cql3::util::do_with_parser(cql_fragment, cql3::dialect{}, std::mem_fn(rule))); return v; } @@ -143,7 +143,7 @@ static T test_valid(std::string_view cql_fragment, producer_rule_ptr rule) { template void test_valid(std::string_view cql_fragment, modifier_rule_ptr rule, T& v) { BOOST_REQUIRE_NO_THROW( - cql3::util::do_with_parser(cql_fragment, [rule, &v](cql3_parser::CqlParser& parser) { + cql3::util::do_with_parser(cql_fragment, cql3::dialect{}, [rule, &v](cql3_parser::CqlParser& parser) { (parser.*rule)(v); // Any non-`void` value will do. return 0; @@ -179,7 +179,7 @@ BOOST_AUTO_TEST_CASE(user_name) { // Not worth generalizing `test_valid`. BOOST_REQUIRE_THROW( - (cql3::util::do_with_parser("\"Ring-bearer\"", std::mem_fn(&cql3_parser::CqlParser::username))), + (cql3::util::do_with_parser("\"Ring-bearer\"", cql3::dialect{}, std::mem_fn(&cql3_parser::CqlParser::username))), exceptions::syntax_exception); } diff --git a/test/boost/expr_test.cc b/test/boost/expr_test.cc index e393476346..8f92b43a5d 100644 --- a/test/boost/expr_test.cc +++ b/test/boost/expr_test.cc @@ -345,7 +345,7 @@ BOOST_AUTO_TEST_CASE(expr_printer_parse_and_print_test) { }; for(const char* test : tests) { - expression parsed_where = cql3::util::where_clause_to_relations(test); + expression parsed_where = cql3::util::where_clause_to_relations(test, cql3::dialect{}); sstring printed_where = cql3::util::relations_to_where_clause(parsed_where); BOOST_REQUIRE_EQUAL(sstring(test), printed_where); diff --git a/test/boost/statement_restrictions_test.cc b/test/boost/statement_restrictions_test.cc index dfdc295bd1..c72133410a 100644 --- a/test/boost/statement_restrictions_test.cc +++ b/test/boost/statement_restrictions_test.cc @@ -47,7 +47,7 @@ query::clustering_row_ranges slice( query::clustering_row_ranges slice_parse( sstring_view where_clause, cql_test_env& env, const sstring& table_name = "t", const sstring& keyspace_name = "ks") { - return slice(boolean_factors(cql3::util::where_clause_to_relations(where_clause)), env, table_name, keyspace_name); + return slice(boolean_factors(cql3::util::where_clause_to_relations(where_clause, cql3::dialect{})), env, table_name, keyspace_name); } auto I(int32_t x) { return int32_type->decompose(x); } diff --git a/test/lib/cql_test_env.cc b/test/lib/cql_test_env.cc index 1eceefa63c..c46297910b 100644 --- a/test/lib/cql_test_env.cc +++ b/test/lib/cql_test_env.cc @@ -184,6 +184,11 @@ private: }; distributed _core_local; private: + cql3::dialect test_dialect() { + return cql3::dialect{ + }; + } + auto make_query_state() { if (_db.local().has_keyspace(ks_name)) { _core_local.local().client_state.set_keyspace(_db.local(), ks_name); @@ -219,7 +224,7 @@ public: testlog.trace("{}(\"{}\")", __FUNCTION__, text); auto qs = make_query_state(); auto qo = make_shared(cql3::query_options::DEFAULT); - return local_qp().execute_direct_without_checking_exception_message(text, *qs, *qo).then([qs, qo] (auto msg) { + return local_qp().execute_direct_without_checking_exception_message(text, *qs, test_dialect(), *qo).then([qs, qo] (auto msg) { return cql_transport::messages::propagate_exception_as_future(std::move(msg)); }); } @@ -231,7 +236,7 @@ public: testlog.trace("{}(\"{}\")", __FUNCTION__, text); auto qs = make_query_state(); auto& lqo = *qo; - return local_qp().execute_direct_without_checking_exception_message(text, *qs, lqo).then([qs, qo = std::move(qo)] (auto msg) { + return local_qp().execute_direct_without_checking_exception_message(text, *qs, test_dialect(), lqo).then([qs, qo = std::move(qo)] (auto msg) { return cql_transport::messages::propagate_exception_as_future(std::move(msg)); }); } @@ -239,9 +244,9 @@ public: virtual future prepare(sstring query) override { return qp().invoke_on_all([query, this] (auto& local_qp) { auto qs = this->make_query_state(); - return local_qp.prepare(query, *qs).finally([qs] {}).discard_result(); + return local_qp.prepare(query, *qs, test_dialect()).finally([qs] {}).discard_result(); }).then([query, this] { - return local_qp().compute_id(query, ks_name); + return local_qp().compute_id(query, ks_name, test_dialect()); }); } @@ -284,7 +289,7 @@ public: virtual future> get_modification_mutations(const sstring& text) override { auto qs = make_query_state(); - auto cql_stmt = local_qp().get_statement(text, qs->get_client_state())->statement; + auto cql_stmt = local_qp().get_statement(text, qs->get_client_state(), test_dialect())->statement; auto modif_stmt = dynamic_pointer_cast(std::move(cql_stmt)); if (!modif_stmt) { throw std::runtime_error(format("get_stmt_mutations: not a modification statement: {}", text)); @@ -1031,7 +1036,7 @@ public: using cql3::statements::modification_statement; std::vector modifications; boost::transform(queries, back_inserter(modifications), [this](const auto& query) { - auto stmt = local_qp().get_statement(query, _core_local.local().client_state); + auto stmt = local_qp().get_statement(query, _core_local.local().client_state, test_dialect()); if (!dynamic_cast(stmt->statement.get())) { throw exceptions::invalid_request_exception( "Invalid statement in batch: only UPDATE, INSERT and DELETE statements are allowed."); diff --git a/tools/schema_loader.cc b/tools/schema_loader.cc index 0180903ca5..314c033701 100644 --- a/tools/schema_loader.cc +++ b/tools/schema_loader.cc @@ -268,7 +268,8 @@ std::vector do_load_schemas(const db::config& cfg, std::string_view // fall-though to below } auto raw_statement = cql3::query_processor::parse_statement( - fmt::format("CREATE KEYSPACE {} WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': '1'}}", name)); + fmt::format("CREATE KEYSPACE {} WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': '1'}}", name), + cql3::dialect{}); auto prepared_statement = raw_statement->prepare(db, cql_stats); auto* statement = prepared_statement->statement.get(); auto p = dynamic_cast(statement); @@ -280,7 +281,7 @@ std::vector do_load_schemas(const db::config& cfg, std::string_view std::vector> raw_statements; try { - raw_statements = cql3::query_processor::parse_statements(schema_str); + raw_statements = cql3::query_processor::parse_statements(schema_str, cql3::dialect{}); } catch (...) { throw std::runtime_error(format("tools:do_load_schemas(): failed to parse CQL statements: {}", std::current_exception())); } diff --git a/transport/server.cc b/transport/server.cc index dc001404e1..897bdc4544 100644 --- a/transport/server.cc +++ b/transport/server.cc @@ -964,20 +964,20 @@ make_result(int16_t stream, messages::result_message& msg, const tracing::trace_ template future cql_server::connection::process_on_shard(::shared_ptr bounce_msg, uint16_t stream, fragmented_temporary_buffer::istream is, - service::client_state& cs, service_permit permit, tracing::trace_state_ptr trace_state, Process process_fn) { + service::client_state& cs, service_permit permit, tracing::trace_state_ptr trace_state, cql3::dialect dialect, Process process_fn) { return _server.container().invoke_on(*bounce_msg->move_to_shard(), _server._config.bounce_request_smp_service_group, [this, is = std::move(is), cs = cs.move_to_other_shard(), stream, permit = std::move(permit), process_fn, gt = tracing::global_trace_state_ptr(std::move(trace_state)), - cached_vals = std::move(bounce_msg->take_cached_pk_function_calls())] (cql_server& server) { + cached_vals = std::move(bounce_msg->take_cached_pk_function_calls()), dialect] (cql_server& server) { service::client_state client_state = cs.get(); return do_with(bytes_ostream(), std::move(client_state), std::move(cached_vals), [this, &server, is = std::move(is), stream, process_fn, - trace_state = tracing::trace_state_ptr(gt)] (bytes_ostream& linearization_buffer, + trace_state = tracing::trace_state_ptr(gt), dialect] (bytes_ostream& linearization_buffer, service::client_state& client_state, cql3::computed_function_values& cached_vals) mutable { request_reader in(is, linearization_buffer); return process_fn(client_state, server._query_processor, in, stream, _version, - /* FIXME */empty_service_permit(), std::move(trace_state), false, std::move(cached_vals)).then([] (auto msg) { + /* FIXME */empty_service_permit(), std::move(trace_state), false, std::move(cached_vals), dialect).then([] (auto msg) { // result here has to be foreign ptr return std::get(std::move(msg)); }); @@ -999,13 +999,14 @@ cql_server::connection::process(uint16_t stream, request_reader in, service::cli tracing::trace_state_ptr trace_state, Process process_fn) { fragmented_temporary_buffer::istream is = in.get_stream(); + auto dialect = get_dialect(); return process_fn(client_state, _server._query_processor, in, stream, - _version, permit, trace_state, true, {}) - .then([stream, &client_state, this, is, permit, process_fn, trace_state] + _version, permit, trace_state, true, {}, dialect) + .then([stream, &client_state, this, is, permit, process_fn, trace_state, dialect] (process_fn_return_type msg) mutable { auto* bounce_msg = std::get_if>(&msg); if (bounce_msg) { - return process_on_shard(*bounce_msg, stream, is, client_state, std::move(permit), trace_state, process_fn); + return process_on_shard(*bounce_msg, stream, is, client_state, std::move(permit), trace_state, dialect, process_fn); } auto ptr = std::get(std::move(msg)); return make_ready_future(std::move(ptr)); @@ -1015,7 +1016,8 @@ cql_server::connection::process(uint16_t stream, request_reader in, service::cli static future process_query_internal(service::client_state& client_state, distributed& qp, request_reader in, uint16_t stream, cql_protocol_version_type version, - service_permit permit, tracing::trace_state_ptr trace_state, bool init_trace, cql3::computed_function_values cached_pk_fn_calls) { + service_permit permit, tracing::trace_state_ptr trace_state, bool init_trace, cql3::computed_function_values cached_pk_fn_calls, + cql3::dialect dialect) { auto query = in.read_long_string_view(); auto q_state = std::make_unique(client_state, trace_state, std::move(permit)); auto& query_state = q_state->query_state; @@ -1036,7 +1038,7 @@ process_query_internal(service::client_state& client_state, distributedmove_to_shard()) { return process_fn_return_type(dynamic_pointer_cast(msg)); } else if (msg->is_exception()) { @@ -1057,15 +1059,16 @@ future> cql_server::connection::process_pr tracing::trace_state_ptr trace_state) { auto query = sstring(in.read_long_string_view()); + auto dialect = get_dialect(); tracing::add_query(trace_state, query); tracing::begin(trace_state, "Preparing CQL3 query", client_state.get_client_address()); - return _server._query_processor.invoke_on_others([query, &client_state] (auto& qp) mutable { - return qp.prepare(std::move(query), client_state).discard_result(); - }).then([this, query, stream, &client_state, trace_state] () mutable { + return _server._query_processor.invoke_on_others([query, &client_state, dialect] (auto& qp) mutable { + return qp.prepare(std::move(query), client_state, dialect).discard_result(); + }).then([this, query, stream, &client_state, trace_state, dialect] () mutable { tracing::trace(trace_state, "Done preparing on remote shards"); - return _server._query_processor.local().prepare(std::move(query), client_state).then([this, stream, trace_state] (auto msg) { + return _server._query_processor.local().prepare(std::move(query), client_state, dialect).then([this, stream, trace_state] (auto msg) { tracing::trace(trace_state, "Done preparing on a local shard - preparing a result. ID is [{}]", seastar::value_of([&msg] { return messages::result_message::prepared::cql::get_id(msg); })); @@ -1077,8 +1080,9 @@ future> cql_server::connection::process_pr static future process_execute_internal(service::client_state& client_state, distributed& qp, request_reader in, uint16_t stream, cql_protocol_version_type version, - service_permit permit, tracing::trace_state_ptr trace_state, bool init_trace, cql3::computed_function_values cached_pk_fn_calls) { - cql3::prepared_cache_key_type cache_key(in.read_short_bytes()); + service_permit permit, tracing::trace_state_ptr trace_state, bool init_trace, cql3::computed_function_values cached_pk_fn_calls, + cql3::dialect dialect) { + cql3::prepared_cache_key_type cache_key(in.read_short_bytes(), dialect); auto& id = cql3::prepared_cache_key_type::cql_id(cache_key); bool needs_authorization = false; @@ -1152,7 +1156,7 @@ future cql_server::connection::pro static future process_batch_internal(service::client_state& client_state, distributed& qp, request_reader in, uint16_t stream, cql_protocol_version_type version, - service_permit permit, tracing::trace_state_ptr trace_state, bool init_trace, cql3::computed_function_values cached_pk_fn_calls) { + service_permit permit, tracing::trace_state_ptr trace_state, bool init_trace, cql3::computed_function_values cached_pk_fn_calls, cql3::dialect dialect) { const auto type = in.read_byte(); const unsigned n = in.read_short(); @@ -1177,7 +1181,7 @@ process_batch_internal(service::client_state& client_state, distributedchecked_weak_from_this(); if (init_trace) { tracing::add_query(trace_state, query); @@ -1185,7 +1189,7 @@ process_batch_internal(service::client_state& client_state, distributed cql_server::connection::process_batch(uint16_t stream, request_reader in, service::client_state& client_state, service_permit permit, tracing::trace_state_ptr trace_state) { diff --git a/transport/server.hh b/transport/server.hh index aaca7f82e8..ee1135ac82 100644 --- a/transport/server.hh +++ b/transport/server.hh @@ -29,6 +29,7 @@ #include "generic_server.hh" #include "service/query_state.hh" #include "cql3/query_options.hh" +#include "cql3/dialect.hh" #include "transport/messages/result_message.hh" #include "utils/chunked_vector.hh" #include "exceptions/coordinator_result.hh" @@ -275,6 +276,8 @@ private: std::unique_ptr make_auth_success(int16_t, bytes, const tracing::trace_state_ptr& tr_state) const; std::unique_ptr make_auth_challenge(int16_t, bytes, const tracing::trace_state_ptr& tr_state) const; + cql3::dialect get_dialect() const; + // Helper functions to encapsulate bounce_to_shard processing for query, execute and batch verbs template future @@ -283,7 +286,7 @@ private: template future process_on_shard(::shared_ptr bounce_msg, uint16_t stream, fragmented_temporary_buffer::istream is, service::client_state& cs, - service_permit permit, tracing::trace_state_ptr trace_state, Process process_fn); + service_permit permit, tracing::trace_state_ptr trace_state, cql3::dialect dialect, Process process_fn); void write_response(foreign_ptr>&& response, service_permit permit = empty_service_permit(), cql_compression compression = cql_compression::none);