diff --git a/cql3/authorized_prepared_statements_cache.hh b/cql3/authorized_prepared_statements_cache.hh new file mode 100644 index 0000000000..4eff0b1c4a --- /dev/null +++ b/cql3/authorized_prepared_statements_cache.hh @@ -0,0 +1,189 @@ +/* + * Copyright (C) 2018 ScyllaDB + */ + +/* + * This file is part of Scylla. + * + * Scylla is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * Scylla is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with Scylla. If not, see . + */ + +#pragma once + +#include "cql3/prepared_statements_cache.hh" + +namespace cql3 { + +struct authorized_prepared_statements_cache_size { + size_t operator()(const statements::prepared_statement::checked_weak_ptr& val) { + // TODO: improve the size approximation - most of the entry is occupied by the key here. + return 100; + } +}; + +class authorized_prepared_statements_cache_key { +public: + using cache_key_type = std::pair; +private: + cache_key_type _key; + +public: + authorized_prepared_statements_cache_key(auth::authenticated_user user, cql3::prepared_cache_key_type prepared_cache_key) + : _key(std::move(user), std::move(prepared_cache_key.key())) {} + + cache_key_type& key() { return _key; } + + const cache_key_type& key() const { return _key; } + + bool operator==(const authorized_prepared_statements_cache_key& other) const { + return _key == other._key; + } + + bool operator!=(const authorized_prepared_statements_cache_key& other) const { + return !(*this == other); + } + + static size_t hash(const auth::authenticated_user& user, const cql3::prepared_cache_key_type::cache_key_type& prep_cache_key) { + return utils::hash_combine(std::hash()(user), utils::tuple_hash()(prep_cache_key)); + } +}; + +/// \class authorized_prepared_statements_cache +/// \brief A cache of previously authorized statements. +/// +/// Entries are inserted every time a new statement is authorized. +/// Entries are evicted in any of the following cases: +/// - When the corresponding prepared statement is not valid anymore. +/// - Periodically, with the same period as the permission cache is refreshed. +/// - If the corresponding entry hasn't been used for \ref entry_expiry. +class authorized_prepared_statements_cache { +public: + struct stats { + uint64_t authorized_prepared_statements_cache_evictions = 0; + }; + + static stats& shard_stats() { + static thread_local stats _stats; + return _stats; + } + + struct authorized_prepared_statements_cache_stats_updater { + static void inc_hits() noexcept {} + static void inc_misses() noexcept {} + static void inc_blocks() noexcept {} + static void inc_evictions() noexcept { + ++shard_stats().authorized_prepared_statements_cache_evictions; + } + }; + +private: + using cache_key_type = authorized_prepared_statements_cache_key; + using checked_weak_ptr = typename statements::prepared_statement::checked_weak_ptr; + using cache_type = utils::loading_cache, + std::equal_to, + authorized_prepared_statements_cache_stats_updater>; + + static const std::chrono::minutes entry_expiry; + +public: + using key_type = cache_key_type; + using value_type = checked_weak_ptr; + using entry_is_too_big = typename cache_type::entry_is_too_big; + using iterator = typename cache_type::iterator; + +private: + cache_type _cache; + logging::logger& _logger; + +public: + // Choose the memory budget such that would allow us ~4K entries when a shard gets 1GB of RAM + authorized_prepared_statements_cache(std::chrono::milliseconds entry_refresh, logging::logger& logger) + : _cache(memory::stats().total_memory() / 2560, entry_expiry, entry_refresh, logger, [this] (const key_type& k) { + _cache.remove(k); + return make_ready_future(); + }) + , _logger(logger) + {} + + future<> insert(auth::authenticated_user user, cql3::prepared_cache_key_type prep_cache_key, value_type v) noexcept { + return _cache.get_ptr(key_type(std::move(user), std::move(prep_cache_key)), [v = std::move(v)] (const cache_key_type&) mutable { + return make_ready_future(std::move(v)); + }).discard_result(); + } + + iterator find(const auth::authenticated_user& user, const cql3::prepared_cache_key_type& prep_cache_key) { + struct key_view { + const auth::authenticated_user& user_ref; + const cql3::prepared_cache_key_type& prep_cache_key_ref; + }; + + struct hasher { + size_t operator()(const key_view& kv) { + return cql3::authorized_prepared_statements_cache_key::hash(kv.user_ref, kv.prep_cache_key_ref.key()); + } + }; + + struct equal { + bool operator()(const key_type& k1, const key_view& k2) { + return k1.key().first == k2.user_ref && k1.key().second == k2.prep_cache_key_ref.key(); + } + + bool operator()(const key_view& k2, const key_type& k1) { + return operator()(k1, k2); + } + }; + + return _cache.find(key_view{user, prep_cache_key}, hasher(), equal()); + } + + iterator end() { + return _cache.end(); + } + + void remove(const auth::authenticated_user& user, const cql3::prepared_cache_key_type& prep_cache_key) { + iterator it = find(user, prep_cache_key); + _cache.remove(it); + } + + size_t size() const { + return _cache.size(); + } + + size_t memory_footprint() const { + return _cache.memory_footprint(); + } + + future<> stop() { + return _cache.stop(); + } +}; + +} + +namespace std { +template <> +struct hash final { + size_t operator()(const cql3::authorized_prepared_statements_cache_key& k) const { + return cql3::authorized_prepared_statements_cache_key::hash(k.key().first, k.key().second); + } +}; + +inline std::ostream& operator<<(std::ostream& out, const cql3::authorized_prepared_statements_cache_key& k) { + return out << "{ " << k.key().first << ", " << k.key().second << " }"; +} +} \ No newline at end of file diff --git a/cql3/prepared_statements_cache.hh b/cql3/prepared_statements_cache.hh index c345fe520e..338186aff0 100644 --- a/cql3/prepared_statements_cache.hh +++ b/cql3/prepared_statements_cache.hh @@ -68,6 +68,14 @@ public: static thrift_prepared_id_type thrift_id(const prepared_cache_key_type& key) { return key.key().second; } + + bool operator==(const prepared_cache_key_type& other) const { + return _key == other._key; + } + + bool operator!=(const prepared_cache_key_type& other) const { + return !(*this == other); + } }; class prepared_statements_cache { @@ -155,6 +163,10 @@ public: size_t memory_footprint() const { return _cache.memory_footprint(); } + + future<> stop() { + return _cache.stop(); + } }; } @@ -168,4 +180,11 @@ inline std::ostream& operator<<(std::ostream& os, const cql3::prepared_cache_key os << p.key(); return os; } + +template<> +struct hash final { + size_t operator()(const cql3::prepared_cache_key_type& k) const { + return utils::tuple_hash()(k.key()); + } +}; } diff --git a/cql3/query_processor.cc b/cql3/query_processor.cc index 693349653f..a6c657b7ac 100644 --- a/cql3/query_processor.cc +++ b/cql3/query_processor.cc @@ -58,12 +58,14 @@ using namespace cql_transport::messages; logging::logger log("query_processor"); logging::logger prep_cache_log("prepared_statements_cache"); +logging::logger authorized_prepared_statements_cache_log("authorized_prepared_statements_cache"); distributed _the_query_processor; const sstring query_processor::CQL_VERSION = "3.3.1"; const std::chrono::minutes prepared_statements_cache::entry_expiry = std::chrono::minutes(60); +const std::chrono::minutes authorized_prepared_statements_cache::entry_expiry = std::chrono::minutes(60); class query_processor::internal_state { service::query_state _qs; @@ -96,7 +98,8 @@ query_processor::query_processor(service::storage_proxy& proxy, distributed query_processor::stop() { service::get_local_migration_manager().unregister_listener(_migration_subscriber.get()); - return make_ready_future<>(); + return _authorized_prepared_cache.stop().finally([this] { return _prepared_cache.stop(); }); } future<::shared_ptr> @@ -230,33 +249,60 @@ query_processor::process(const sstring_view& query_string, service::query_state& metrics.regularStatementsExecuted.inc(); #endif tracing::trace(query_state.get_trace_state(), "Processing a statement"); - return process_statement(std::move(cql_statement), query_state, options); + return process_statement_unprepared(std::move(cql_statement), query_state, options); } future<::shared_ptr> -query_processor::process_statement( +query_processor::process_statement_unprepared( ::shared_ptr statement, service::query_state& query_state, const query_options& options) { - return statement->check_access(query_state.get_client_state()).then([this, statement, &query_state, &options]() { - auto& client_state = query_state.get_client_state(); + return statement->check_access(query_state.get_client_state()).then([this, statement, &query_state, &options] () mutable { + return process_authorized_statement(std::move(statement), query_state, options); + }); +} - statement->validate(_proxy, client_state); +future<::shared_ptr> +query_processor::process_statement_prepared( + statements::prepared_statement::checked_weak_ptr prepared, + cql3::prepared_cache_key_type cache_key, + service::query_state& query_state, + const query_options& options, + bool needs_authorization) { - auto fut = make_ready_future<::shared_ptr>(); - if (client_state.is_internal()) { - fut = statement->execute_internal(_proxy, query_state, options); - } else { - fut = statement->execute(_proxy, query_state, options); - } - - return fut.then([statement] (auto msg) { - if (msg) { - return make_ready_future<::shared_ptr>(std::move(msg)); - } - return make_ready_future<::shared_ptr>( - ::make_shared()); + ::shared_ptr statement = prepared->statement; + future<> fut = make_ready_future<>(); + if (needs_authorization) { + fut = statement->check_access(query_state.get_client_state()).then([this, &query_state, prepared = std::move(prepared), cache_key = std::move(cache_key)] () mutable { + return _authorized_prepared_cache.insert(*query_state.get_client_state().user(), std::move(cache_key), std::move(prepared)).handle_exception([this] (auto eptr) { + log.error("failed to cache the entry", eptr); + }); }); + } + + return fut.then([this, statement = std::move(statement), &query_state, &options] () mutable { + return process_authorized_statement(std::move(statement), query_state, options); + }); +} + +future<::shared_ptr> +query_processor::process_authorized_statement(const ::shared_ptr statement, service::query_state& query_state, const query_options& options) { + auto& client_state = query_state.get_client_state(); + + statement->validate(_proxy, client_state); + + auto fut = make_ready_future<::shared_ptr>(); + if (client_state.is_internal()) { + fut = statement->execute_internal(_proxy, query_state, options); + } else { + fut = statement->execute(_proxy, query_state, options); + } + + return fut.then([statement] (auto msg) { + if (msg) { + return make_ready_future<::shared_ptr>(std::move(msg)); + } + return make_ready_future<::shared_ptr>(::make_shared()); }); } @@ -558,11 +604,18 @@ future<::shared_ptr> query_processor::process_batch( ::shared_ptr batch, service::query_state& query_state, - query_options& options) { - return batch->check_access(query_state.get_client_state()).then([this, &query_state, &options, batch] { - batch->validate(); - batch->validate(_proxy, query_state.get_client_state()); - return batch->execute(_proxy, query_state, options); + query_options& options, + std::unordered_map pending_authorization_entries) { + return batch->check_access(query_state.get_client_state()).then([this, &query_state, &options, batch, pending_authorization_entries = std::move(pending_authorization_entries)] () mutable { + return parallel_for_each(pending_authorization_entries, [this, &query_state] (auto& e) { + return _authorized_prepared_cache.insert(*query_state.get_client_state().user(), e.first, std::move(e.second)).handle_exception([this] (auto eptr) { + log.error("failed to cache the entry", eptr); + }); + }).then([this, &query_state, &options, batch] { + batch->validate(); + batch->validate(_proxy, query_state.get_client_state()); + return batch->execute(_proxy, query_state, options); + }); }); } diff --git a/cql3/query_processor.hh b/cql3/query_processor.hh index 1c6a2b1b18..2f671e0c96 100644 --- a/cql3/query_processor.hh +++ b/cql3/query_processor.hh @@ -49,6 +49,7 @@ #include #include "cql3/prepared_statements_cache.hh" +#include "cql3/authorized_prepared_statements_cache.hh" #include "cql3/query_options.hh" #include "cql3/statements/prepared_statement.hh" #include "cql3/statements/raw/parsed_statement.hh" @@ -117,6 +118,7 @@ private: std::unique_ptr _internal_state; prepared_statements_cache _prepared_cache; + authorized_prepared_statements_cache _authorized_prepared_cache; // A map for prepared statements used internally (which we don't want to mix with user statement, in particular we // don't bother with expiration on those. @@ -151,6 +153,21 @@ public: return _cql_stats; } + statements::prepared_statement::checked_weak_ptr get_prepared(const auth::authenticated_user* user_ptr, const prepared_cache_key_type& key) { + if (user_ptr) { + auto it = _authorized_prepared_cache.find(*user_ptr, key); + if (it != _authorized_prepared_cache.end()) { + try { + return it->get()->checked_weak_from_this(); + } catch (seastar::checked_ptr_is_null_exception&) { + // If the prepared statement got invalidated - remove the corresponding authorized_prepared_statements_cache entry as well. + _authorized_prepared_cache.remove(*user_ptr, key); + } + } + } + return statements::prepared_statement::checked_weak_ptr(); + } + statements::prepared_statement::checked_weak_ptr get_prepared(const prepared_cache_key_type& key) { auto it = _prepared_cache.find(key); if (it == _prepared_cache.end()) { @@ -160,11 +177,19 @@ public: } future<::shared_ptr> - process_statement( + process_statement_unprepared( ::shared_ptr statement, service::query_state& query_state, const query_options& options); + future<::shared_ptr> + process_statement_prepared( + statements::prepared_statement::checked_weak_ptr statement, + cql3::prepared_cache_key_type cache_key, + service::query_state& query_state, + const query_options& options, + bool needs_authorization); + future<::shared_ptr> process( const std::experimental::string_view& query_string, @@ -242,7 +267,11 @@ public: future<> stop(); future<::shared_ptr> - process_batch(::shared_ptr, service::query_state& query_state, query_options& options); + process_batch( + ::shared_ptr, + service::query_state& query_state, + query_options& options, + std::unordered_map pending_authorization_entries); std::unique_ptr get_statement( const std::experimental::string_view& query, @@ -257,6 +286,9 @@ private: db::consistency_level = db::consistency_level::ONE, int32_t page_size = -1); + future<::shared_ptr> + process_authorized_statement(const ::shared_ptr statement, service::query_state& query_state, const query_options& options); + /*! * \brief created a state object for paging * diff --git a/cql3/statements/batch_statement.cc b/cql3/statements/batch_statement.cc index 96d9db1f13..8916781b95 100644 --- a/cql3/statements/batch_statement.cc +++ b/cql3/statements/batch_statement.cc @@ -75,19 +75,19 @@ timeout_for_type(batch_statement::type t) { } batch_statement::batch_statement(int bound_terms, type type_, - std::vector> statements, + std::vector statements, std::unique_ptr attrs, cql_stats& stats) : cql_statement_no_metadata(timeout_for_type(type_)) , _bound_terms(bound_terms), _type(type_), _statements(std::move(statements)) , _attrs(std::move(attrs)) - , _has_conditions(boost::algorithm::any_of(_statements, std::mem_fn(&modification_statement::has_conditions))) + , _has_conditions(boost::algorithm::any_of(_statements, [] (auto&& s) { return s.statement->has_conditions(); })) , _stats(stats) { } batch_statement::batch_statement(type type_, - std::vector> statements, + std::vector statements, std::unique_ptr attrs, cql_stats& stats) : batch_statement(-1, type_, std::move(statements), std::move(attrs), stats) @@ -97,7 +97,7 @@ batch_statement::batch_statement(type type_, bool batch_statement::uses_function(const sstring& ks_name, const sstring& function_name) const { return _attrs->uses_function(ks_name, function_name) - || boost::algorithm::any_of(_statements, [&] (auto&& s) { return s->uses_function(ks_name, function_name); }); + || boost::algorithm::any_of(_statements, [&] (auto&& s) { return s.statement->uses_function(ks_name, function_name); }); } bool batch_statement::depends_on_keyspace(const sstring& ks_name) const @@ -118,7 +118,11 @@ uint32_t batch_statement::get_bound_terms() future<> batch_statement::check_access(const service::client_state& state) { return parallel_for_each(_statements.begin(), _statements.end(), [&state](auto&& s) { - return s->check_access(state); + if (s.needs_authorization) { + return s.statement->check_access(state); + } else { + return make_ready_future<>(); + } }); } @@ -138,12 +142,12 @@ void batch_statement::validate() } } - bool has_counters = boost::algorithm::any_of(_statements, std::mem_fn(&modification_statement::is_counter)); - bool has_non_counters = !boost::algorithm::all_of(_statements, std::mem_fn(&modification_statement::is_counter)); + bool has_counters = boost::algorithm::any_of(_statements, [] (auto&& s) { return s.statement->is_counter(); }); + bool has_non_counters = !boost::algorithm::all_of(_statements, [] (auto&& s) { return s.statement->is_counter(); }); if (timestamp_set && has_counters) { throw exceptions::invalid_request_exception("Cannot provide custom timestamp for a BATCH containing counters"); } - if (timestamp_set && boost::algorithm::any_of(_statements, std::mem_fn(&modification_statement::is_timestamp_set))) { + if (timestamp_set && boost::algorithm::any_of(_statements, [] (auto&& s) { return s.statement->is_timestamp_set(); })) { throw exceptions::invalid_request_exception("Timestamp must be set either on BATCH or individual statements"); } if (_type == type::COUNTER && has_non_counters) { @@ -159,30 +163,30 @@ void batch_statement::validate() if (_has_conditions && !_statements.empty() && (boost::distance(_statements - | boost::adaptors::transformed(std::mem_fn(&modification_statement::keyspace)) + | boost::adaptors::transformed([] (auto&& s) { return s.statement->keyspace(); }) | boost::adaptors::uniqued) != 1 || (boost::distance(_statements - | boost::adaptors::transformed(std::mem_fn(&modification_statement::column_family)) + | boost::adaptors::transformed([] (auto&& s) { return s.statement->column_family(); }) | boost::adaptors::uniqued) != 1))) { throw exceptions::invalid_request_exception("Batch with conditions cannot span multiple tables"); } std::experimental::optional raw_counter; for (auto& s : _statements) { - if (raw_counter && s->is_raw_counter_shard_write() != *raw_counter) { + if (raw_counter && s.statement->is_raw_counter_shard_write() != *raw_counter) { throw exceptions::invalid_request_exception("Cannot mix raw and regular counter statements in batch"); } - raw_counter = s->is_raw_counter_shard_write(); + raw_counter = s.statement->is_raw_counter_shard_write(); } } void batch_statement::validate(service::storage_proxy& proxy, const service::client_state& state) { for (auto&& s : _statements) { - s->validate(proxy, state); + s.statement->validate(proxy, state); } } -const std::vector>& batch_statement::get_statements() +const std::vector& batch_statement::get_statements() { return _statements; } @@ -196,7 +200,7 @@ future> batch_statement::get_mutations(service::storage_pr return do_for_each(boost::make_counting_iterator(0), boost::make_counting_iterator(_statements.size()), [this, &storage, &options, now, local, &result, trace_state] (size_t i) { - auto&& statement = _statements[i]; + auto&& statement = _statements[i].statement; statement->inc_cql_stats(); auto&& statement_options = options.for_statement(i); auto timestamp = _attrs->get_timestamp(now, statement_options); @@ -426,7 +430,9 @@ batch_statement::prepare(database& db, cql_stats& stats) { stdx::optional first_cf; bool have_multiple_cfs = false; - std::vector> statements; + std::vector statements; + statements.reserve(_parsed_statements.size()); + for (auto&& parsed : _parsed_statements) { if (!first_ks) { first_ks = parsed->keyspace(); @@ -434,7 +440,7 @@ batch_statement::prepare(database& db, cql_stats& stats) { } else { have_multiple_cfs = first_ks.value() != parsed->keyspace() || first_cf.value() != parsed->column_family(); } - statements.push_back(parsed->prepare(db, bound_names, stats)); + statements.emplace_back(parsed->prepare(db, bound_names, stats)); } auto&& prep_attrs = _attrs->prepare(db, "[batch]", "[batch]"); @@ -445,7 +451,7 @@ batch_statement::prepare(database& db, cql_stats& stats) { std::vector partition_key_bind_indices; if (!have_multiple_cfs && batch_statement_.get_statements().size() > 0) { - partition_key_bind_indices = bound_names->get_partition_key_bind_indexes(batch_statement_.get_statements()[0]->s); + partition_key_bind_indices = bound_names->get_partition_key_bind_indexes(batch_statement_.get_statements()[0].statement->s); } return std::make_unique(make_shared(std::move(batch_statement_)), bound_names->get_specifications(), diff --git a/cql3/statements/batch_statement.hh b/cql3/statements/batch_statement.hh index cc0d4595b9..f31d24e79f 100644 --- a/cql3/statements/batch_statement.hh +++ b/cql3/statements/batch_statement.hh @@ -66,10 +66,24 @@ class batch_statement : public cql_statement_no_metadata { static logging::logger _logger; public: using type = raw::batch_statement::type; + + struct single_statement { + shared_ptr statement; + bool needs_authorization = true; + + public: + single_statement(shared_ptr s) + : statement(std::move(s)) + {} + single_statement(shared_ptr s, bool na) + : statement(std::move(s)) + , needs_authorization(na) + {} + }; private: int _bound_terms; type _type; - std::vector> _statements; + std::vector _statements; std::unique_ptr _attrs; bool _has_conditions; cql_stats& _stats; @@ -83,12 +97,12 @@ public: * @param attrs additional attributes for statement (CL, timestamp, timeToLive) */ batch_statement(int bound_terms, type type_, - std::vector> statements, + std::vector statements, std::unique_ptr attrs, cql_stats& stats); batch_statement(type type_, - std::vector> statements, + std::vector statements, std::unique_ptr attrs, cql_stats& stats); @@ -109,7 +123,7 @@ public: // or in QueryProcessor.processBatch() - for native protocol batches. virtual void validate(service::storage_proxy& proxy, const service::client_state& state) override; - const std::vector>& get_statements(); + const std::vector& get_statements(); private: future> get_mutations(service::storage_proxy& storage, const query_options& options, bool local, api::timestamp_type now, tracing::trace_state_ptr trace_state); diff --git a/tests/cql_test_env.cc b/tests/cql_test_env.cc index c3080c2dc3..59743618da 100644 --- a/tests/cql_test_env.cc +++ b/tests/cql_test_env.cc @@ -166,7 +166,7 @@ public: options->prepare(prepared->bound_names); auto qs = make_query_state(); - return local_qp().process_statement(stmt, *qs, *options) + return local_qp().process_statement_prepared(std::move(prepared), std::move(id), *qs, *options, true) .finally([options, qs, this] { _core_local.local().client_state.merge(qs->get_client_state()); }); diff --git a/tests/loading_cache_test.cc b/tests/loading_cache_test.cc index 4e320eb8f5..51993f3046 100644 --- a/tests/loading_cache_test.cc +++ b/tests/loading_cache_test.cc @@ -229,6 +229,41 @@ SEASTAR_TEST_CASE(test_loading_cache_loading_same_key) { }); } +SEASTAR_THREAD_TEST_CASE(test_loading_cache_removing_key) { + using namespace std::chrono; + load_count = 0; + utils::loading_cache loading_cache(num_loaders, 100s, test_logger); + auto stop_cache_reload = seastar::defer([&loading_cache] { loading_cache.stop().get(); }); + + prepare().get(); + + loading_cache.get_ptr(0, loader).discard_result().get(); + BOOST_REQUIRE_EQUAL(load_count, 1); + BOOST_REQUIRE(loading_cache.find(0) != loading_cache.end()); + + loading_cache.remove(0); + BOOST_REQUIRE(loading_cache.find(0) == loading_cache.end()); +} + +SEASTAR_THREAD_TEST_CASE(test_loading_cache_removing_iterator) { + using namespace std::chrono; + load_count = 0; + utils::loading_cache loading_cache(num_loaders, 100s, test_logger); + auto stop_cache_reload = seastar::defer([&loading_cache] { loading_cache.stop().get(); }); + + prepare().get(); + + loading_cache.get_ptr(0, loader).discard_result().get(); + BOOST_REQUIRE_EQUAL(load_count, 1); + + auto it = loading_cache.find(0); + + BOOST_REQUIRE(it != loading_cache.end()); + + loading_cache.remove(it); + BOOST_REQUIRE(loading_cache.find(0) == loading_cache.end()); +} + SEASTAR_TEST_CASE(test_loading_cache_loading_different_keys) { return seastar::async([] { using namespace std::chrono; diff --git a/thrift/handler.cc b/thrift/handler.cc index 815a449970..44f0c72811 100644 --- a/thrift/handler.cc +++ b/thrift/handler.cc @@ -999,9 +999,17 @@ public: void execute_prepared_cql3_query(tcxx::function cob, tcxx::function exn_cob, const int32_t itemId, const std::vector & values, const ConsistencyLevel::type consistency) { with_exn_cob(std::move(exn_cob), [&] { - auto prepared = _query_processor.local().get_prepared(cql3::prepared_cache_key_type(itemId)); + cql3::prepared_cache_key_type cache_key(itemId); + bool needs_authorization = false; + + auto prepared = _query_processor.local().get_prepared(_query_state.get_client_state().user().get(), cache_key); if (!prepared) { - throw make_exception("Prepared query with id %d not found", itemId); + needs_authorization = true; + + prepared = _query_processor.local().get_prepared(cache_key); + if (!prepared) { + throw make_exception("Prepared query with id %d not found", itemId); + } } auto stmt = prepared->statement; if (stmt->get_bound_terms() != values.size()) { @@ -1013,7 +1021,7 @@ public: }); auto opts = std::make_unique(cl_from_thrift(consistency), _timeout_config, stdx::nullopt, std::move(bytes_values), false, cql3::query_options::specific_options::DEFAULT, cql_serialization_format::latest()); - auto f = _query_processor.local().process_statement(stmt, _query_state, *opts); + auto f = _query_processor.local().process_statement_prepared(std::move(prepared), std::move(cache_key), _query_state, *opts, needs_authorization); return f.then([cob = std::move(cob), opts = std::move(opts)](auto&& ret) { cql3_result_visitor visitor; ret->accept(visitor); diff --git a/tracing/trace_keyspace_helper.cc b/tracing/trace_keyspace_helper.cc index 63c1d42de9..1aa5236a5f 100644 --- a/tracing/trace_keyspace_helper.cc +++ b/tracing/trace_keyspace_helper.cc @@ -357,7 +357,7 @@ future<> trace_keyspace_helper::apply_events_mutation(lw_shared_ptrsession_id, events_records.size(), records->parent_id, records->my_span_id); - std::vector> modifications(events_records.size(), _events.insert_stmt()); + std::vector modifications(events_records.size(), cql3::statements::batch_statement::single_statement(_events.insert_stmt(), false)); std::vector> values; auto& qp = cql3::get_local_query_processor(); diff --git a/transport/server.cc b/transport/server.cc index e21bf63bf5..c4233d0644 100644 --- a/transport/server.cc +++ b/transport/server.cc @@ -910,7 +910,16 @@ future cql_server::connection::process_execute(uint16_t stream, b { cql3::prepared_cache_key_type cache_key(read_short_bytes(buf)); auto& id = cql3::prepared_cache_key_type::cql_id(cache_key); - auto prepared = _server._query_processor.local().get_prepared(cache_key); + bool needs_authorization = false; + + // First, try to lookup in the cache of already authorized statements. If the corresponding entry is not found there + // look for the prepared statement and then authorize it. + auto prepared = _server._query_processor.local().get_prepared(client_state.user().get(), cache_key); + if (!prepared) { + needs_authorization = true; + prepared = _server._query_processor.local().get_prepared(cache_key); + } + if (!prepared) { throw exceptions::prepared_query_not_found_exception(id); } @@ -945,7 +954,7 @@ future cql_server::connection::process_execute(uint16_t stream, b throw exceptions::invalid_request_exception("Invalid amount of bind variables"); } tracing::trace(query_state.get_trace_state(), "Processing a statement"); - return _server._query_processor.local().process_statement(stmt, query_state, options).then([this, stream, buf = std::move(buf), &query_state, skip_metadata] (auto msg) { + return _server._query_processor.local().process_statement_prepared(std::move(prepared), std::move(cache_key), query_state, options, needs_authorization).then([this, stream, buf = std::move(buf), &query_state, skip_metadata] (auto msg) { tracing::trace(query_state.get_trace_state(), "Done processing - preparing a result"); return this->make_result(stream, msg, query_state.get_trace_state(), skip_metadata); }).then([&query_state, q_state = std::move(q_state), this] (auto&& response) { @@ -964,8 +973,9 @@ cql_server::connection::process_batch(uint16_t stream, bytes_view buf, service:: const auto type = read_byte(buf); const unsigned n = read_short(buf); - std::vector> modifications; + std::vector modifications; std::vector> values; + std::unordered_map pending_authorization_entries; modifications.reserve(n); values.reserve(n); @@ -977,6 +987,7 @@ cql_server::connection::process_batch(uint16_t stream, bytes_view buf, service:: std::unique_ptr stmt_ptr; cql3::statements::prepared_statement::checked_weak_ptr ps; + bool needs_authorization(kind == 0); switch (kind) { case 0: { @@ -988,10 +999,19 @@ cql_server::connection::process_batch(uint16_t stream, bytes_view buf, service:: case 1: { cql3::prepared_cache_key_type cache_key(read_short_bytes(buf)); auto& id = cql3::prepared_cache_key_type::cql_id(cache_key); - ps = _server._query_processor.local().get_prepared(cache_key); + + // First, try to lookup in the cache of already authorized statements. If the corresponding entry is not found there + // look for the prepared statement and then authorize it. + ps = _server._query_processor.local().get_prepared(client_state.user().get(), cache_key); if (!ps) { - throw exceptions::prepared_query_not_found_exception(id); + ps = _server._query_processor.local().get_prepared(cache_key); + if (!ps) { + throw exceptions::prepared_query_not_found_exception(id); + } + // authorize a particular prepared statement only once + needs_authorization = pending_authorization_entries.emplace(std::move(cache_key), ps->checked_weak_from_this()).second; } + break; } default: @@ -1007,7 +1027,7 @@ cql_server::connection::process_batch(uint16_t stream, bytes_view buf, service:: ::shared_ptr modif_statement_ptr = static_pointer_cast(ps->statement); tracing::add_table_name(client_state.get_trace_state(), modif_statement_ptr->keyspace(), modif_statement_ptr->column_family()); - modifications.emplace_back(std::move(modif_statement_ptr)); + modifications.emplace_back(std::move(modif_statement_ptr), needs_authorization); std::vector tmp; read_value_view_list(buf, tmp); @@ -1031,7 +1051,7 @@ cql_server::connection::process_batch(uint16_t stream, bytes_view buf, service:: tracing::trace(client_state.get_trace_state(), "Creating a batch statement"); auto batch = ::make_shared(cql3::statements::batch_statement::type(type), std::move(modifications), cql3::attributes::none(), _server._query_processor.local().get_cql_stats()); - return _server._query_processor.local().process_batch(batch, query_state, options).then([this, stream, batch, &query_state] (auto msg) { + return _server._query_processor.local().process_batch(batch, query_state, options, std::move(pending_authorization_entries)).then([this, stream, batch, &query_state] (auto msg) { return this->make_result(stream, msg, query_state.get_trace_state()); }).then([&query_state, q_state = std::move(q_state), this] (auto&& response) { /* Keep q_state alive. */ diff --git a/utils/loading_cache.hh b/utils/loading_cache.hh index d8a96ce2e9..8ae74de30a 100644 --- a/utils/loading_cache.hh +++ b/utils/loading_cache.hh @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -79,6 +80,10 @@ public: value_type& value() noexcept { return _value; } const value_type& value() const noexcept { return _value; } + static const timestamped_val& container_of(const value_type& value) { + return *bi::get_parent_from_member(&value, ×tamped_val::_value); + } + loading_cache_clock_type::time_point last_read() const noexcept { return _last_read; } @@ -95,6 +100,10 @@ public: return _lru_entry_ptr; } + lru_entry* lru_entry_ptr() const noexcept { + return _lru_entry_ptr; + } + private: void touch() noexcept { assert(_lru_entry_ptr); @@ -365,6 +374,11 @@ public: return _timer_reads_gate.close().finally([this] { _timer.cancel(); }); } + template + iterator find(const KeyType& key, KeyHasher key_hasher_func, KeyEqual key_equal_func) noexcept { + return boost::make_transform_iterator(set_find(key, std::move(key_hasher_func), std::move(key_equal_func)), _value_extractor_fn); + }; + iterator find(const Key& k) noexcept { return boost::make_transform_iterator(set_find(k), _value_extractor_fn); } @@ -388,6 +402,24 @@ public: }); } + void remove(const Key& k) { + auto it = set_find(k); + if (it == set_end()) { + return; + } + + _lru_list.erase_and_dispose(_lru_list.iterator_to(*it->lru_entry_ptr()), [this] (ts_value_lru_entry* p) { loading_cache::destroy_ts_value(p); }); + } + + void remove(iterator it) { + if (it == end()) { + return; + } + + const ts_value_type& val = ts_value_type::container_of(*it); + _lru_list.erase_and_dispose(_lru_list.iterator_to(*val.lru_entry_ptr()), [this] (ts_value_lru_entry* p) { loading_cache::destroy_ts_value(p); }); + } + size_t size() const { return _loading_values.size(); } @@ -398,8 +430,7 @@ public: } private: - set_iterator set_find(const Key& k) noexcept { - set_iterator it = _loading_values.find(k); + set_iterator ready_entry_iterator(set_iterator it) { set_iterator end_it = set_end(); if (it == end_it || !it->ready()) { @@ -408,6 +439,17 @@ private: return it; } + template + set_iterator set_find(const KeyType& key, KeyHasher key_hasher_func, KeyEqual key_equal_func) noexcept { + return ready_entry_iterator(_loading_values.find(key, std::move(key_hasher_func), std::move(key_equal_func))); + } + + // keep the default non-templated overloads to ease on the compiler for specifications + // that do not require the templated find(). + set_iterator set_find(const Key& key) noexcept { + return ready_entry_iterator(_loading_values.find(key)); + } + set_iterator set_end() noexcept { return _loading_values.end(); } diff --git a/utils/loading_shared_values.hh b/utils/loading_shared_values.hh index bfa6a276ff..55cc7fcae7 100644 --- a/utils/loading_shared_values.hh +++ b/utils/loading_shared_values.hh @@ -107,16 +107,6 @@ private: return bool(_val); } - struct key_eq { - bool operator()(const key_type& k, const entry& c) const { - return EqualPred()(k, c.key()); - } - - bool operator()(const entry& c, const key_type& k) const { - return EqualPred()(c.key(), k); - } - }; - entry(loading_shared_values& parent, key_type k) : _parent(parent), _key(std::move(k)) {} @@ -134,6 +124,17 @@ private: } }; + template + struct key_eq { + bool operator()(const KeyType& k, const entry& c) const { + return KeyEqual()(k, c.key()); + } + + bool operator()(const entry& c, const KeyType& k) const { + return KeyEqual()(c.key(), k); + } + }; + using set_type = bi::unordered_set, bi::compare_hash>; using bi_set_bucket_traits = typename set_type::bucket_traits; using set_iterator = typename set_type::iterator; @@ -216,7 +217,7 @@ public: future get_or_load(const key_type& key, Loader&& loader) noexcept { static_assert(std::is_same, typename futurize>::type>::value, "Bad Loader signature"); try { - auto i = _set.find(key, Hash(), typename entry::key_eq()); + auto i = _set.find(key, Hash(), key_eq()); lw_shared_ptr e; future<> f = make_ready_future<>(); if (i != _set.end()) { @@ -280,12 +281,19 @@ public: return boost::make_transform_iterator(_set.begin(), _value_extractor_fn); } - iterator find(const key_type& key) noexcept { - set_iterator it = _set.find(key, Hash(), typename entry::key_eq()); + template + iterator find(const KeyType& key, KeyHasher key_hasher_func, KeyEqual key_equal_func) noexcept { + set_iterator it = _set.find(key, std::move(key_hasher_func), key_eq()); if (it == _set.end() || !it->ready()) { return end(); } return boost::make_transform_iterator(it, _value_extractor_fn); + }; + + // keep the default non-templated overloads to ease on the compiler for specifications + // that do not require the templated find(). + iterator find(const key_type& key) noexcept { + return find(key, Hash(), EqualPred()); } private: