diff --git a/auth/common.cc b/auth/common.cc index 1cf3a6a200..7e0b94ef19 100644 --- a/auth/common.cc +++ b/auth/common.cc @@ -72,7 +72,7 @@ static future<> create_legacy_metadata_table_if_missing_impl( SCYLLA_ASSERT(this_shard_id() == 0); // once_among_shards makes sure a function is executed on shard 0 only auto db = qp.db(); - auto parsed_statement = cql3::query_processor::parse_statement(cql); + auto parsed_statement = cql3::query_processor::parse_statement(cql, cql3::dialect{}); auto& parsed_cf_statement = static_cast(*parsed_statement); parsed_cf_statement.prepare_keyspace(meta::legacy::AUTH_KS); diff --git a/cql3/Cql.g b/cql3/Cql.g index b99e97250a..39341b28f3 100644 --- a/cql3/Cql.g +++ b/cql3/Cql.g @@ -68,6 +68,7 @@ options { #include "cql3/statements/ks_prop_defs.hh" #include "cql3/selection/raw_selector.hh" #include "cql3/selection/selectable-expr.hh" +#include "cql3/dialect.hh" #include "cql3/keyspace_element_name.hh" #include "cql3/constants.hh" #include "cql3/operation_impl.hh" @@ -148,6 +149,8 @@ using uexpression = uninitialized; listener_type* listener; + dialect _dialect; + // Keeps the names of all bind variables. For bind variables without a name ('?'), the name is nullptr. // Maps bind_index -> name. std::vector<::shared_ptr> _bind_variable_names; @@ -171,6 +174,10 @@ using uexpression = uninitialized; return s; } + void set_dialect(dialect d) { + _dialect = d; + } + bind_variable new_bind_variables(shared_ptr name) { if (name && _named_bind_variables_indexes.contains(*name)) { diff --git a/cql3/cql3_type.cc b/cql3/cql3_type.cc index ff9bbbb559..115b12655e 100644 --- a/cql3/cql3_type.cc +++ b/cql3/cql3_type.cc @@ -449,7 +449,8 @@ sstring maybe_quote(const sstring& identifier) { // many keywords but allow keywords listed as "unreserved keywords". // So we can use any of them, for example cident. try { - cql3::util::do_with_parser(identifier, std::mem_fn(&cql3_parser::CqlParser::cident)); + // In general it's not a good idea to use the default dialect, but for parsing an identifier, it's okay. + cql3::util::do_with_parser(identifier, dialect{}, std::mem_fn(&cql3_parser::CqlParser::cident)); return identifier; } catch(exceptions::syntax_exception&) { // This alphanumeric string is not a valid identifier, so fall diff --git a/cql3/dialect.hh b/cql3/dialect.hh new file mode 100644 index 0000000000..a714426ba0 --- /dev/null +++ b/cql3/dialect.hh @@ -0,0 +1,31 @@ +// Copyright (C) 2024-present ScyllaDB +// SPDX-License-Identifier: AGPL-3.0-or-later + +#pragma once + +#include + +namespace cql3 { + +struct dialect { + bool operator==(const dialect&) const = default; +}; + +inline +dialect +internal_dialect() { + return dialect{ + }; +} + +} + +template <> +struct fmt::formatter { + constexpr auto parse(format_parse_context& ctx) { return ctx.begin(); } + + template + auto format(const cql3::dialect& d, FormatContext& ctx) const { + return fmt::format_to(ctx.out(), "cql3::dialect{{}}"); + } +}; diff --git a/cql3/prepared_statements_cache.hh b/cql3/prepared_statements_cache.hh index 796dc918f1..7b041265b3 100644 --- a/cql3/prepared_statements_cache.hh +++ b/cql3/prepared_statements_cache.hh @@ -14,6 +14,7 @@ #include "utils/hash.hh" #include "cql3/statements/prepared_statement.hh" #include "cql3/column_specification.hh" +#include "cql3/dialect.hh" namespace cql3 { @@ -37,13 +38,17 @@ class prepared_cache_key_type { public: // derive from cql_prepared_id_type so we can customize the formatter of // cache_key_type - struct cache_key_type : public cql_prepared_id_type {}; + struct cache_key_type : public cql_prepared_id_type { + cache_key_type(cql_prepared_id_type&& id, cql3::dialect d) : cql_prepared_id_type(std::move(id)), dialect(d) {} + cql3::dialect dialect; // Not part of hash, but we don't expect collisions because of that + bool operator==(const cache_key_type& other) const = default; + }; private: cache_key_type _key; public: - explicit prepared_cache_key_type(cql_prepared_id_type cql_id) : _key(std::move(cql_id)) {} + explicit prepared_cache_key_type(cql_prepared_id_type cql_id, dialect d) : _key(std::move(cql_id), d) {} cache_key_type& key() { return _key; } const cache_key_type& key() const { return _key; } @@ -175,7 +180,7 @@ struct hash final { template <> struct fmt::formatter { constexpr auto parse(format_parse_context& ctx) { return ctx.begin(); } auto format(const cql3::prepared_cache_key_type::cache_key_type& p, fmt::format_context& ctx) const { - return fmt::format_to(ctx.out(), "{{cql_id: {}}}", static_cast(p)); + return fmt::format_to(ctx.out(), "{{cql_id: {}, dialect: {}}}", static_cast(p), p.dialect); } }; diff --git a/cql3/query_processor.cc b/cql3/query_processor.cc index 4ab775d45e..50f8e00c74 100644 --- a/cql3/query_processor.cc +++ b/cql3/query_processor.cc @@ -565,10 +565,10 @@ query_processor::execute_maybe_with_guard(service::query_state& query_state, ::s } future<::shared_ptr> -query_processor::execute_direct_without_checking_exception_message(const sstring_view& query_string, service::query_state& query_state, query_options& options) { +query_processor::execute_direct_without_checking_exception_message(const sstring_view& query_string, service::query_state& query_state, dialect d, query_options& options) { log.trace("execute_direct: \"{}\"", query_string); tracing::trace(query_state.get_trace_state(), "Parsing a statement"); - auto p = get_statement(query_string, query_state.get_client_state()); + auto p = get_statement(query_string, query_state.get_client_state(), d); auto statement = p->statement; const auto warnings = std::move(p->warnings); if (statement->get_bound_terms() != options.get_values_count()) { @@ -652,18 +652,21 @@ query_processor::process_authorized_statement(const ::shared_ptr } future<::shared_ptr> -query_processor::prepare(sstring query_string, service::query_state& query_state) { +query_processor::prepare(sstring query_string, service::query_state& query_state, cql3::dialect d) { auto& client_state = query_state.get_client_state(); - return prepare(std::move(query_string), client_state); + return prepare(std::move(query_string), client_state, d); } future<::shared_ptr> -query_processor::prepare(sstring query_string, const service::client_state& client_state) { +query_processor::prepare(sstring query_string, const service::client_state& client_state, cql3::dialect d) { using namespace cql_transport::messages; return prepare_one( std::move(query_string), client_state, - compute_id, + d, + [d] (std::string_view query_string, std::string_view keyspace) { + return compute_id(query_string, keyspace, d); + }, prepared_cache_key_type::cql_id); } @@ -675,13 +678,14 @@ static std::string hash_target(std::string_view query_string, std::string_view k prepared_cache_key_type query_processor::compute_id( std::string_view query_string, - std::string_view keyspace) { - return prepared_cache_key_type(md5_hasher::calculate(hash_target(query_string, keyspace))); + std::string_view keyspace, + dialect d) { + return prepared_cache_key_type(md5_hasher::calculate(hash_target(query_string, keyspace)), d); } std::unique_ptr -query_processor::get_statement(const sstring_view& query, const service::client_state& client_state) { - std::unique_ptr statement = parse_statement(query); +query_processor::get_statement(const sstring_view& query, const service::client_state& client_state, dialect d) { + std::unique_ptr statement = parse_statement(query, d); // Set keyspace for statement that require login auto cf_stmt = dynamic_cast(statement.get()); @@ -695,7 +699,7 @@ query_processor::get_statement(const sstring_view& query, const service::client_ } std::unique_ptr -query_processor::parse_statement(const sstring_view& query) { +query_processor::parse_statement(const sstring_view& query, dialect d) { try { { const char* error_injection_key = "query_processor-parse_statement-test_failure"; @@ -705,7 +709,7 @@ query_processor::parse_statement(const sstring_view& query) { } }); } - auto statement = util::do_with_parser(query, std::mem_fn(&cql3_parser::CqlParser::query)); + auto statement = util::do_with_parser(query, d, std::mem_fn(&cql3_parser::CqlParser::query)); if (!statement) { throw exceptions::syntax_exception("Parsing failed"); } @@ -721,9 +725,9 @@ query_processor::parse_statement(const sstring_view& query) { } std::vector> -query_processor::parse_statements(std::string_view queries) { +query_processor::parse_statements(std::string_view queries, dialect d) { try { - auto statements = util::do_with_parser(queries, std::mem_fn(&cql3_parser::CqlParser::queries)); + auto statements = util::do_with_parser(queries, d, std::mem_fn(&cql3_parser::CqlParser::queries)); if (statements.empty()) { throw exceptions::syntax_exception("Parsing failed"); } @@ -796,7 +800,7 @@ query_options query_processor::make_internal_options( statements::prepared_statement::checked_weak_ptr query_processor::prepare_internal(const sstring& query_string) { auto& p = _internal_statements[query_string]; if (p == nullptr) { - auto np = parse_statement(query_string)->prepare(_db, _cql_stats); + auto np = parse_statement(query_string, internal_dialect())->prepare(_db, _cql_stats); np->statement->raw_cql_statement = query_string; p = std::move(np); // inserts it into map } @@ -902,7 +906,8 @@ query_processor::execute_internal( auto p = prepare_internal(query_string); return execute_with_params(std::move(p), cl, query_state, values); } else { - auto p = parse_statement(query_string)->prepare(_db, _cql_stats); + // For internal queries, we want the default dialect, not the user provided one + auto p = parse_statement(query_string, dialect{})->prepare(_db, _cql_stats); p->statement->raw_cql_statement = query_string; auto checked_weak_ptr = p->checked_weak_from_this(); return execute_with_params(std::move(checked_weak_ptr), cl, query_state, values).finally([p = std::move(p)] {}); diff --git a/cql3/query_processor.hh b/cql3/query_processor.hh index b0e6d4fb6a..2677ad004e 100644 --- a/cql3/query_processor.hh +++ b/cql3/query_processor.hh @@ -21,6 +21,7 @@ #include "cql3/authorized_prepared_statements_cache.hh" #include "cql3/statements/prepared_statement.hh" #include "cql3/cql_statement.hh" +#include "cql3/dialect.hh" #include "exceptions/exceptions.hh" #include "service/migration_listener.hh" #include "timestamp.hh" @@ -138,10 +139,11 @@ public: static prepared_cache_key_type compute_id( std::string_view query_string, - std::string_view keyspace); + std::string_view keyspace, + dialect d); - static std::unique_ptr parse_statement(const std::string_view& query); - static std::vector> parse_statements(std::string_view queries); + static std::unique_ptr parse_statement(const std::string_view& query, dialect d); + static std::vector> parse_statements(std::string_view queries, dialect d); query_processor(service::storage_proxy& proxy, data_dictionary::database db, service::migration_notifier& mn, memory_config mcfg, cql_config& cql_cfg, utils::loading_cache_config auth_prep_cache_cfg, lang::manager& langm); @@ -250,10 +252,12 @@ public: execute_direct( const std::string_view& query_string, service::query_state& query_state, + dialect d, query_options& options) { return execute_direct_without_checking_exception_message( query_string, query_state, + d, options) .then(cql_transport::messages::propagate_exception_as_future<::shared_ptr>); } @@ -264,6 +268,7 @@ public: execute_direct_without_checking_exception_message( const std::string_view& query_string, service::query_state& query_state, + dialect d, query_options& options); future<::shared_ptr> @@ -398,10 +403,10 @@ public: future<::shared_ptr> - prepare(sstring query_string, service::query_state& query_state); + prepare(sstring query_string, service::query_state& query_state, dialect d); future<::shared_ptr> - prepare(sstring query_string, const service::client_state& client_state); + prepare(sstring query_string, const service::client_state& client_state, dialect d); future<> stop(); @@ -444,7 +449,8 @@ public: std::unique_ptr get_statement( const std::string_view& query, - const service::client_state& client_state); + const service::client_state& client_state, + dialect d); friend class migration_subscriber; @@ -528,14 +534,15 @@ private: prepare_one( sstring query_string, const service::client_state& client_state, + dialect d, PreparedKeyGenerator&& id_gen, IdGetter&& id_getter) { return do_with( id_gen(query_string, client_state.get_raw_keyspace()), std::move(query_string), - [this, &client_state, &id_getter](const prepared_cache_key_type& key, const sstring& query_string) { - return _prepared_cache.get(key, [this, &query_string, &client_state] { - auto prepared = get_statement(query_string, client_state); + [this, &client_state, &id_getter, d](const prepared_cache_key_type& key, const sstring& query_string) { + return _prepared_cache.get(key, [this, &query_string, &client_state, d] { + auto prepared = get_statement(query_string, client_state, d); auto bound_terms = prepared->statement->get_bound_terms(); if (bound_terms > std::numeric_limits::max()) { throw exceptions::invalid_request_exception( diff --git a/cql3/statements/alter_table_statement.cc b/cql3/statements/alter_table_statement.cc index e5f05f1c14..d0100d140a 100644 --- a/cql3/statements/alter_table_statement.cc +++ b/cql3/statements/alter_table_statement.cc @@ -385,7 +385,8 @@ std::pair> alter_table_statement::prepare_ auto new_where = util::rename_column_in_where_clause( view->view_info()->where_clause(), column_identifier::raw(view_from->text(), true), - column_identifier::raw(view_to->text(), true)); + column_identifier::raw(view_to->text(), true), + cql3::dialect{}); builder.with_view_info(view->view_info()->base_id(), view->view_info()->base_name(), view->view_info()->include_all_columns(), std::move(new_where)); diff --git a/cql3/statements/select_statement.cc b/cql3/statements/select_statement.cc index 194af80164..8970aaab1c 100644 --- a/cql3/statements/select_statement.cc +++ b/cql3/statements/select_statement.cc @@ -2591,7 +2591,9 @@ std::unique_ptr build_select_statement( if (!where_clause.empty()) { out << " WHERE " << where_clause << " ALLOW FILTERING"; } - return do_with_parser(out.str(), std::mem_fn(&cql3_parser::CqlParser::selectStatement)); + // In general it's not a good idea to use the default dialect, but here the database is talking to + // itself, so we can hope the dialects are mutually compatible here. + return do_with_parser(out.str(), dialect{}, std::mem_fn(&cql3_parser::CqlParser::selectStatement)); } } diff --git a/cql3/util.cc b/cql3/util.cc index 88d6d5a626..7aa65a67a7 100644 --- a/cql3/util.cc +++ b/cql3/util.cc @@ -21,7 +21,7 @@ void __sanitizer_finish_switch_fiber(void* fake_stack_save, const void** stack_b namespace cql3::util { -static void do_with_parser_impl_impl(const sstring_view& cql, noncopyable_function f) { +static void do_with_parser_impl_impl(const sstring_view& cql, dialect d, noncopyable_function f) { cql3_parser::CqlLexer::collector_type lexer_error_collector(cql); cql3_parser::CqlParser::collector_type parser_error_collector(cql); cql3_parser::CqlLexer::InputStreamType input{reinterpret_cast(cql.begin()), ANTLR_ENC_UTF8, static_cast(cql.size()), nullptr}; @@ -30,13 +30,14 @@ static void do_with_parser_impl_impl(const sstring_view& cql, noncopyable_functi cql3_parser::CqlParser::TokenStreamType tstream(ANTLR_SIZE_HINT, lexer.get_tokSource()); cql3_parser::CqlParser parser{&tstream}; parser.set_error_listener(parser_error_collector); + parser.set_dialect(d); f(parser); } #ifndef DEBUG -void do_with_parser_impl(const sstring_view& cql, noncopyable_function f) { - return do_with_parser_impl_impl(cql, std::move(f)); +void do_with_parser_impl(const sstring_view& cql, dialect d, noncopyable_function f) { + return do_with_parser_impl_impl(cql, d, std::move(f)); } #else @@ -48,6 +49,7 @@ void do_with_parser_impl(const sstring_view& cql, noncopyable_function&& func; // Exceptions can't be returned from another stack, so store // any thrown exception here @@ -71,7 +73,7 @@ static void thunk(int p1, int p2) { // Complete stack switch started in do_with_parser_impl() __sanitizer_finish_switch_fiber(nullptr, &san.stack_bottom, &san.stack_size); try { - do_with_parser_impl_impl(args->cql, std::move(args->func)); + do_with_parser_impl_impl(args->cql, args->d, std::move(args->func)); } catch (...) { args->ex = std::current_exception(); } @@ -80,11 +82,12 @@ static void thunk(int p1, int p2) { setcontext(&args->caller_stack); }; -void do_with_parser_impl(const sstring_view& cql, noncopyable_function f) { +void do_with_parser_impl(const sstring_view& cql, dialect d, noncopyable_function f) { static constexpr size_t stack_size = 1 << 20; static thread_local std::unique_ptr stack = std::make_unique(stack_size); thunk_args args{ .cql = cql, + .d = d, .func = std::move(f), }; ucontext_t uc; @@ -93,7 +96,7 @@ void do_with_parser_impl(const sstring_view& cql, noncopyable_function relations = boolean_factors(where_clause_to_relations(where_clause)); +sstring rename_column_in_where_clause(const sstring_view& where_clause, column_identifier::raw from, column_identifier::raw to, dialect d) { + std::vector relations = boolean_factors(where_clause_to_relations(where_clause, d)); std::vector new_relations; new_relations.reserve(relations.size()); diff --git a/cql3/util.hh b/cql3/util.hh index c179e8b520..fa37b19a21 100644 --- a/cql3/util.hh +++ b/cql3/util.hh @@ -21,18 +21,19 @@ #include "cql3/CqlParser.hpp" #include "cql3/error_collector.hh" #include "cql3/statements/raw/select_statement.hh" +#include "cql3/dialect.hh" namespace cql3 { namespace util { -void do_with_parser_impl(const sstring_view& cql, noncopyable_function func); +void do_with_parser_impl(const sstring_view& cql, dialect d, noncopyable_function func); template >> -Result do_with_parser(const sstring_view& cql, Func&& f) { +Result do_with_parser(const sstring_view& cql, dialect d, Func&& f) { std::optional ret; - do_with_parser_impl(cql, [&] (cql3_parser::CqlParser& parser) { + do_with_parser_impl(cql, d, [&] (cql3_parser::CqlParser& parser) { ret.emplace(f(parser)); }); return std::move(*ret); @@ -40,9 +41,9 @@ Result do_with_parser(const sstring_view& cql, Func&& f) { sstring relations_to_where_clause(const expr::expression& e); -expr::expression where_clause_to_relations(const sstring_view& where_clause); +expr::expression where_clause_to_relations(const sstring_view& where_clause, dialect d); -sstring rename_column_in_where_clause(const sstring_view& where_clause, column_identifier::raw from, column_identifier::raw to); +sstring rename_column_in_where_clause(const sstring_view& where_clause, column_identifier::raw from, column_identifier::raw to, dialect d); /// build a CQL "select" statement with the desired parameters. /// If select_all_columns==true, all columns are selected and the value of diff --git a/db/cql_type_parser.cc b/db/cql_type_parser.cc index a740e015ed..e8661dde5b 100644 --- a/db/cql_type_parser.cc +++ b/db/cql_type_parser.cc @@ -20,7 +20,9 @@ #include "utils/sorting.hh" static ::shared_ptr parse_raw(const sstring& str) { - return cql3::util::do_with_parser(str, + // In general it's a bad idea to use the default dialect, but type parsing + // should be dialect-agnostic. + return cql3::util::do_with_parser(str, cql3::dialect{}, [] (cql3_parser::CqlParser& parser) { return parser.comparator_type(true); }); diff --git a/db/schema_tables.cc b/db/schema_tables.cc index 85592e4c29..3a8289a5a4 100644 --- a/db/schema_tables.cc +++ b/db/schema_tables.cc @@ -1943,7 +1943,9 @@ static shared_ptr create_aggregate(replica::dat bytes_opt initcond = std::nullopt; if (initcond_str) { - auto expr = cql3::util::do_with_parser(*initcond_str, std::mem_fn(&cql3_parser::CqlParser::term)); + // In general using the default dialect is wrong, but here the database is communicating with itself, + // not the user, so any dialect should work. + auto expr = cql3::util::do_with_parser(*initcond_str, cql3::dialect{}, std::mem_fn(&cql3_parser::CqlParser::term)); auto dummy_ident = ::make_shared("", true); auto column_spec = make_lw_shared("", "", dummy_ident, state_type); auto raw = cql3::expr::evaluate(prepare_expression(expr, db.as_data_dictionary(), "", nullptr, {column_spec}), cql3::query_options::DEFAULT); diff --git a/table_helper.cc b/table_helper.cc index a6d6bb456a..5fc15d6013 100644 --- a/table_helper.cc +++ b/table_helper.cc @@ -22,7 +22,7 @@ static logging::logger tlogger("table_helper"); static schema_ptr parse_new_cf_statement(cql3::query_processor& qp, const sstring& create_cql) { auto db = qp.db(); - auto parsed = cql3::query_processor::parse_statement(create_cql); + auto parsed = cql3::query_processor::parse_statement(create_cql, cql3::dialect{}); cql3::statements::raw::cf_statement* parsed_cf_stmt = static_cast(parsed.get()); (void)parsed_cf_stmt->keyspace(); // This will SCYLLA_ASSERT if cql statement did not contain keyspace @@ -68,11 +68,11 @@ future<> table_helper::setup_table(cql3::query_processor& qp, service::migration } catch (...) {} } -future table_helper::try_prepare(bool fallback, cql3::query_processor& qp, service::query_state& qs) { +future table_helper::try_prepare(bool fallback, cql3::query_processor& qp, service::query_state& qs, cql3::dialect dialect) { // Note: `_insert_cql_fallback` is known to be engaged if `fallback` is true, see cache_table_info below. auto& stmt = fallback ? _insert_cql_fallback.value() : _insert_cql; try { - shared_ptr msg_ptr = co_await qp.prepare(stmt, qs.get_client_state()); + shared_ptr msg_ptr = co_await qp.prepare(stmt, qs.get_client_state(), dialect); _prepared_stmt = std::move(msg_ptr->get_prepared()); shared_ptr cql_stmt = _prepared_stmt->statement; _insert_stmt = dynamic_pointer_cast(cql_stmt); @@ -104,12 +104,12 @@ future<> table_helper::cache_table_info(cql3::query_processor& qp, service::migr } try { - bool success = co_await try_prepare(false, qp, qs); + bool success = co_await try_prepare(false, qp, qs, cql3::internal_dialect()); if (_is_fallback_stmt && _prepared_stmt) { co_return; } if (!success) { - co_await try_prepare(true, qp, qs); // Can only return true or exception when preparing the fallback statement + co_await try_prepare(true, qp, qs, cql3::internal_dialect()); // Can only return true or exception when preparing the fallback statement } } catch (...) { auto eptr = std::current_exception(); diff --git a/table_helper.hh b/table_helper.hh index c873948509..4e9217c8e9 100644 --- a/table_helper.hh +++ b/table_helper.hh @@ -18,6 +18,7 @@ class migration_manager; namespace cql3 { class query_processor; +class dialect; namespace statements { class modification_statement; }} @@ -43,7 +44,7 @@ private: bool _is_fallback_stmt = false; private: // Returns true is prepare succeeded, false if failed and there's still a chance to recover, exception if prepare failed and it's not possible to recover - future try_prepare(bool fallback, cql3::query_processor& qp, service::query_state& qs); + future try_prepare(bool fallback, cql3::query_processor& qp, service::query_state& qs, cql3::dialect dialect); public: table_helper(std::string_view keyspace, std::string_view name, sstring create_cql, sstring insert_cql, std::optional insert_cql_fallback = std::nullopt) : _keyspace(keyspace) diff --git a/test/boost/cql_auth_syntax_test.cc b/test/boost/cql_auth_syntax_test.cc index 3ae5a9c32c..09b96591a4 100644 --- a/test/boost/cql_auth_syntax_test.cc +++ b/test/boost/cql_auth_syntax_test.cc @@ -132,7 +132,7 @@ using modifier_rule_ptr = void (cql3_parser::CqlParser::*)(T&); template static T test_valid(std::string_view cql_fragment, producer_rule_ptr rule) { T v; - BOOST_REQUIRE_NO_THROW(v = cql3::util::do_with_parser(cql_fragment, std::mem_fn(rule))); + BOOST_REQUIRE_NO_THROW(v = cql3::util::do_with_parser(cql_fragment, cql3::dialect{}, std::mem_fn(rule))); return v; } @@ -143,7 +143,7 @@ static T test_valid(std::string_view cql_fragment, producer_rule_ptr rule) { template void test_valid(std::string_view cql_fragment, modifier_rule_ptr rule, T& v) { BOOST_REQUIRE_NO_THROW( - cql3::util::do_with_parser(cql_fragment, [rule, &v](cql3_parser::CqlParser& parser) { + cql3::util::do_with_parser(cql_fragment, cql3::dialect{}, [rule, &v](cql3_parser::CqlParser& parser) { (parser.*rule)(v); // Any non-`void` value will do. return 0; @@ -179,7 +179,7 @@ BOOST_AUTO_TEST_CASE(user_name) { // Not worth generalizing `test_valid`. BOOST_REQUIRE_THROW( - (cql3::util::do_with_parser("\"Ring-bearer\"", std::mem_fn(&cql3_parser::CqlParser::username))), + (cql3::util::do_with_parser("\"Ring-bearer\"", cql3::dialect{}, std::mem_fn(&cql3_parser::CqlParser::username))), exceptions::syntax_exception); } diff --git a/test/boost/expr_test.cc b/test/boost/expr_test.cc index e393476346..8f92b43a5d 100644 --- a/test/boost/expr_test.cc +++ b/test/boost/expr_test.cc @@ -345,7 +345,7 @@ BOOST_AUTO_TEST_CASE(expr_printer_parse_and_print_test) { }; for(const char* test : tests) { - expression parsed_where = cql3::util::where_clause_to_relations(test); + expression parsed_where = cql3::util::where_clause_to_relations(test, cql3::dialect{}); sstring printed_where = cql3::util::relations_to_where_clause(parsed_where); BOOST_REQUIRE_EQUAL(sstring(test), printed_where); diff --git a/test/boost/statement_restrictions_test.cc b/test/boost/statement_restrictions_test.cc index dfdc295bd1..c72133410a 100644 --- a/test/boost/statement_restrictions_test.cc +++ b/test/boost/statement_restrictions_test.cc @@ -47,7 +47,7 @@ query::clustering_row_ranges slice( query::clustering_row_ranges slice_parse( sstring_view where_clause, cql_test_env& env, const sstring& table_name = "t", const sstring& keyspace_name = "ks") { - return slice(boolean_factors(cql3::util::where_clause_to_relations(where_clause)), env, table_name, keyspace_name); + return slice(boolean_factors(cql3::util::where_clause_to_relations(where_clause, cql3::dialect{})), env, table_name, keyspace_name); } auto I(int32_t x) { return int32_type->decompose(x); } diff --git a/test/lib/cql_test_env.cc b/test/lib/cql_test_env.cc index 1eceefa63c..c46297910b 100644 --- a/test/lib/cql_test_env.cc +++ b/test/lib/cql_test_env.cc @@ -184,6 +184,11 @@ private: }; distributed _core_local; private: + cql3::dialect test_dialect() { + return cql3::dialect{ + }; + } + auto make_query_state() { if (_db.local().has_keyspace(ks_name)) { _core_local.local().client_state.set_keyspace(_db.local(), ks_name); @@ -219,7 +224,7 @@ public: testlog.trace("{}(\"{}\")", __FUNCTION__, text); auto qs = make_query_state(); auto qo = make_shared(cql3::query_options::DEFAULT); - return local_qp().execute_direct_without_checking_exception_message(text, *qs, *qo).then([qs, qo] (auto msg) { + return local_qp().execute_direct_without_checking_exception_message(text, *qs, test_dialect(), *qo).then([qs, qo] (auto msg) { return cql_transport::messages::propagate_exception_as_future(std::move(msg)); }); } @@ -231,7 +236,7 @@ public: testlog.trace("{}(\"{}\")", __FUNCTION__, text); auto qs = make_query_state(); auto& lqo = *qo; - return local_qp().execute_direct_without_checking_exception_message(text, *qs, lqo).then([qs, qo = std::move(qo)] (auto msg) { + return local_qp().execute_direct_without_checking_exception_message(text, *qs, test_dialect(), lqo).then([qs, qo = std::move(qo)] (auto msg) { return cql_transport::messages::propagate_exception_as_future(std::move(msg)); }); } @@ -239,9 +244,9 @@ public: virtual future prepare(sstring query) override { return qp().invoke_on_all([query, this] (auto& local_qp) { auto qs = this->make_query_state(); - return local_qp.prepare(query, *qs).finally([qs] {}).discard_result(); + return local_qp.prepare(query, *qs, test_dialect()).finally([qs] {}).discard_result(); }).then([query, this] { - return local_qp().compute_id(query, ks_name); + return local_qp().compute_id(query, ks_name, test_dialect()); }); } @@ -284,7 +289,7 @@ public: virtual future> get_modification_mutations(const sstring& text) override { auto qs = make_query_state(); - auto cql_stmt = local_qp().get_statement(text, qs->get_client_state())->statement; + auto cql_stmt = local_qp().get_statement(text, qs->get_client_state(), test_dialect())->statement; auto modif_stmt = dynamic_pointer_cast(std::move(cql_stmt)); if (!modif_stmt) { throw std::runtime_error(format("get_stmt_mutations: not a modification statement: {}", text)); @@ -1031,7 +1036,7 @@ public: using cql3::statements::modification_statement; std::vector modifications; boost::transform(queries, back_inserter(modifications), [this](const auto& query) { - auto stmt = local_qp().get_statement(query, _core_local.local().client_state); + auto stmt = local_qp().get_statement(query, _core_local.local().client_state, test_dialect()); if (!dynamic_cast(stmt->statement.get())) { throw exceptions::invalid_request_exception( "Invalid statement in batch: only UPDATE, INSERT and DELETE statements are allowed."); diff --git a/tools/schema_loader.cc b/tools/schema_loader.cc index 0180903ca5..314c033701 100644 --- a/tools/schema_loader.cc +++ b/tools/schema_loader.cc @@ -268,7 +268,8 @@ std::vector do_load_schemas(const db::config& cfg, std::string_view // fall-though to below } auto raw_statement = cql3::query_processor::parse_statement( - fmt::format("CREATE KEYSPACE {} WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': '1'}}", name)); + fmt::format("CREATE KEYSPACE {} WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': '1'}}", name), + cql3::dialect{}); auto prepared_statement = raw_statement->prepare(db, cql_stats); auto* statement = prepared_statement->statement.get(); auto p = dynamic_cast(statement); @@ -280,7 +281,7 @@ std::vector do_load_schemas(const db::config& cfg, std::string_view std::vector> raw_statements; try { - raw_statements = cql3::query_processor::parse_statements(schema_str); + raw_statements = cql3::query_processor::parse_statements(schema_str, cql3::dialect{}); } catch (...) { throw std::runtime_error(format("tools:do_load_schemas(): failed to parse CQL statements: {}", std::current_exception())); } diff --git a/transport/server.cc b/transport/server.cc index dc001404e1..897bdc4544 100644 --- a/transport/server.cc +++ b/transport/server.cc @@ -964,20 +964,20 @@ make_result(int16_t stream, messages::result_message& msg, const tracing::trace_ template future cql_server::connection::process_on_shard(::shared_ptr bounce_msg, uint16_t stream, fragmented_temporary_buffer::istream is, - service::client_state& cs, service_permit permit, tracing::trace_state_ptr trace_state, Process process_fn) { + service::client_state& cs, service_permit permit, tracing::trace_state_ptr trace_state, cql3::dialect dialect, Process process_fn) { return _server.container().invoke_on(*bounce_msg->move_to_shard(), _server._config.bounce_request_smp_service_group, [this, is = std::move(is), cs = cs.move_to_other_shard(), stream, permit = std::move(permit), process_fn, gt = tracing::global_trace_state_ptr(std::move(trace_state)), - cached_vals = std::move(bounce_msg->take_cached_pk_function_calls())] (cql_server& server) { + cached_vals = std::move(bounce_msg->take_cached_pk_function_calls()), dialect] (cql_server& server) { service::client_state client_state = cs.get(); return do_with(bytes_ostream(), std::move(client_state), std::move(cached_vals), [this, &server, is = std::move(is), stream, process_fn, - trace_state = tracing::trace_state_ptr(gt)] (bytes_ostream& linearization_buffer, + trace_state = tracing::trace_state_ptr(gt), dialect] (bytes_ostream& linearization_buffer, service::client_state& client_state, cql3::computed_function_values& cached_vals) mutable { request_reader in(is, linearization_buffer); return process_fn(client_state, server._query_processor, in, stream, _version, - /* FIXME */empty_service_permit(), std::move(trace_state), false, std::move(cached_vals)).then([] (auto msg) { + /* FIXME */empty_service_permit(), std::move(trace_state), false, std::move(cached_vals), dialect).then([] (auto msg) { // result here has to be foreign ptr return std::get(std::move(msg)); }); @@ -999,13 +999,14 @@ cql_server::connection::process(uint16_t stream, request_reader in, service::cli tracing::trace_state_ptr trace_state, Process process_fn) { fragmented_temporary_buffer::istream is = in.get_stream(); + auto dialect = get_dialect(); return process_fn(client_state, _server._query_processor, in, stream, - _version, permit, trace_state, true, {}) - .then([stream, &client_state, this, is, permit, process_fn, trace_state] + _version, permit, trace_state, true, {}, dialect) + .then([stream, &client_state, this, is, permit, process_fn, trace_state, dialect] (process_fn_return_type msg) mutable { auto* bounce_msg = std::get_if>(&msg); if (bounce_msg) { - return process_on_shard(*bounce_msg, stream, is, client_state, std::move(permit), trace_state, process_fn); + return process_on_shard(*bounce_msg, stream, is, client_state, std::move(permit), trace_state, dialect, process_fn); } auto ptr = std::get(std::move(msg)); return make_ready_future(std::move(ptr)); @@ -1015,7 +1016,8 @@ cql_server::connection::process(uint16_t stream, request_reader in, service::cli static future process_query_internal(service::client_state& client_state, distributed& qp, request_reader in, uint16_t stream, cql_protocol_version_type version, - service_permit permit, tracing::trace_state_ptr trace_state, bool init_trace, cql3::computed_function_values cached_pk_fn_calls) { + service_permit permit, tracing::trace_state_ptr trace_state, bool init_trace, cql3::computed_function_values cached_pk_fn_calls, + cql3::dialect dialect) { auto query = in.read_long_string_view(); auto q_state = std::make_unique(client_state, trace_state, std::move(permit)); auto& query_state = q_state->query_state; @@ -1036,7 +1038,7 @@ process_query_internal(service::client_state& client_state, distributedmove_to_shard()) { return process_fn_return_type(dynamic_pointer_cast(msg)); } else if (msg->is_exception()) { @@ -1057,15 +1059,16 @@ future> cql_server::connection::process_pr tracing::trace_state_ptr trace_state) { auto query = sstring(in.read_long_string_view()); + auto dialect = get_dialect(); tracing::add_query(trace_state, query); tracing::begin(trace_state, "Preparing CQL3 query", client_state.get_client_address()); - return _server._query_processor.invoke_on_others([query, &client_state] (auto& qp) mutable { - return qp.prepare(std::move(query), client_state).discard_result(); - }).then([this, query, stream, &client_state, trace_state] () mutable { + return _server._query_processor.invoke_on_others([query, &client_state, dialect] (auto& qp) mutable { + return qp.prepare(std::move(query), client_state, dialect).discard_result(); + }).then([this, query, stream, &client_state, trace_state, dialect] () mutable { tracing::trace(trace_state, "Done preparing on remote shards"); - return _server._query_processor.local().prepare(std::move(query), client_state).then([this, stream, trace_state] (auto msg) { + return _server._query_processor.local().prepare(std::move(query), client_state, dialect).then([this, stream, trace_state] (auto msg) { tracing::trace(trace_state, "Done preparing on a local shard - preparing a result. ID is [{}]", seastar::value_of([&msg] { return messages::result_message::prepared::cql::get_id(msg); })); @@ -1077,8 +1080,9 @@ future> cql_server::connection::process_pr static future process_execute_internal(service::client_state& client_state, distributed& qp, request_reader in, uint16_t stream, cql_protocol_version_type version, - service_permit permit, tracing::trace_state_ptr trace_state, bool init_trace, cql3::computed_function_values cached_pk_fn_calls) { - cql3::prepared_cache_key_type cache_key(in.read_short_bytes()); + service_permit permit, tracing::trace_state_ptr trace_state, bool init_trace, cql3::computed_function_values cached_pk_fn_calls, + cql3::dialect dialect) { + cql3::prepared_cache_key_type cache_key(in.read_short_bytes(), dialect); auto& id = cql3::prepared_cache_key_type::cql_id(cache_key); bool needs_authorization = false; @@ -1152,7 +1156,7 @@ future cql_server::connection::pro static future process_batch_internal(service::client_state& client_state, distributed& qp, request_reader in, uint16_t stream, cql_protocol_version_type version, - service_permit permit, tracing::trace_state_ptr trace_state, bool init_trace, cql3::computed_function_values cached_pk_fn_calls) { + service_permit permit, tracing::trace_state_ptr trace_state, bool init_trace, cql3::computed_function_values cached_pk_fn_calls, cql3::dialect dialect) { const auto type = in.read_byte(); const unsigned n = in.read_short(); @@ -1177,7 +1181,7 @@ process_batch_internal(service::client_state& client_state, distributedchecked_weak_from_this(); if (init_trace) { tracing::add_query(trace_state, query); @@ -1185,7 +1189,7 @@ process_batch_internal(service::client_state& client_state, distributed cql_server::connection::process_batch(uint16_t stream, request_reader in, service::client_state& client_state, service_permit permit, tracing::trace_state_ptr trace_state) { diff --git a/transport/server.hh b/transport/server.hh index aaca7f82e8..ee1135ac82 100644 --- a/transport/server.hh +++ b/transport/server.hh @@ -29,6 +29,7 @@ #include "generic_server.hh" #include "service/query_state.hh" #include "cql3/query_options.hh" +#include "cql3/dialect.hh" #include "transport/messages/result_message.hh" #include "utils/chunked_vector.hh" #include "exceptions/coordinator_result.hh" @@ -275,6 +276,8 @@ private: std::unique_ptr make_auth_success(int16_t, bytes, const tracing::trace_state_ptr& tr_state) const; std::unique_ptr make_auth_challenge(int16_t, bytes, const tracing::trace_state_ptr& tr_state) const; + cql3::dialect get_dialect() const; + // Helper functions to encapsulate bounce_to_shard processing for query, execute and batch verbs template future @@ -283,7 +286,7 @@ private: template future process_on_shard(::shared_ptr bounce_msg, uint16_t stream, fragmented_temporary_buffer::istream is, service::client_state& cs, - service_permit permit, tracing::trace_state_ptr trace_state, Process process_fn); + service_permit permit, tracing::trace_state_ptr trace_state, cql3::dialect dialect, Process process_fn); void write_response(foreign_ptr>&& response, service_permit permit = empty_service_permit(), cql_compression compression = cql_compression::none);