From 5bde36f29e3f6428d649e5b4a6611b844e88a695 Mon Sep 17 00:00:00 2001 From: Vlad Zolotarov Date: Wed, 25 Apr 2018 14:02:11 -0400 Subject: [PATCH 1/7] cql3::query_processor: properly stop() prepared_statements_cache object prepared_statements_cache has a timer that evicts old entries - it needs to be properly stopped. Signed-off-by: Vlad Zolotarov --- cql3/prepared_statements_cache.hh | 4 ++++ cql3/query_processor.cc | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/cql3/prepared_statements_cache.hh b/cql3/prepared_statements_cache.hh index c345fe520e..e6ae7e44e4 100644 --- a/cql3/prepared_statements_cache.hh +++ b/cql3/prepared_statements_cache.hh @@ -155,6 +155,10 @@ public: size_t memory_footprint() const { return _cache.memory_footprint(); } + + future<> stop() { + return _cache.stop(); + } }; } diff --git a/cql3/query_processor.cc b/cql3/query_processor.cc index 693349653f..cf240a44d6 100644 --- a/cql3/query_processor.cc +++ b/cql3/query_processor.cc @@ -210,7 +210,7 @@ query_processor::~query_processor() { future<> query_processor::stop() { service::get_local_migration_manager().unregister_listener(_migration_subscriber.get()); - return make_ready_future<>(); + return _prepared_cache.stop(); } future<::shared_ptr> From 34620deee4ed80f68222acac2079a8b89dd15c23 Mon Sep 17 00:00:00 2001 From: Vlad Zolotarov Date: Tue, 15 May 2018 14:59:03 -0400 Subject: [PATCH 2/7] utils::loading_cache: add remove(key)/remove(iterator) methods remove(key): removes the entry with the given key if exists, otherwise does nothing. remote(iterator): removes an entry by a given iterator (returned from loading_cache::find()). Signed-off-by: Vlad Zolotarov --- utils/loading_cache.hh | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/utils/loading_cache.hh b/utils/loading_cache.hh index d8a96ce2e9..cf718784ec 100644 --- a/utils/loading_cache.hh +++ b/utils/loading_cache.hh @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -79,6 +80,10 @@ public: value_type& value() noexcept { return _value; } const value_type& value() const noexcept { return _value; } + static const timestamped_val& container_of(const value_type& value) { + return *bi::get_parent_from_member(&value, ×tamped_val::_value); + } + loading_cache_clock_type::time_point last_read() const noexcept { return _last_read; } @@ -95,6 +100,10 @@ public: return _lru_entry_ptr; } + lru_entry* lru_entry_ptr() const noexcept { + return _lru_entry_ptr; + } + private: void touch() noexcept { assert(_lru_entry_ptr); @@ -388,6 +397,24 @@ public: }); } + void remove(const Key& k) { + auto it = set_find(k); + if (it == set_end()) { + return; + } + + _lru_list.erase_and_dispose(_lru_list.iterator_to(*it->lru_entry_ptr()), [this] (ts_value_lru_entry* p) { loading_cache::destroy_ts_value(p); }); + } + + void remove(iterator it) { + if (it == end()) { + return; + } + + const ts_value_type& val = ts_value_type::container_of(*it); + _lru_list.erase_and_dispose(_lru_list.iterator_to(*val.lru_entry_ptr()), [this] (ts_value_lru_entry* p) { loading_cache::destroy_ts_value(p); }); + } + size_t size() const { return _loading_values.size(); } From ab251a1fc32f1670a8bf14b06934ff8b11ae813c Mon Sep 17 00:00:00 2001 From: Vlad Zolotarov Date: Tue, 15 May 2018 20:13:35 -0400 Subject: [PATCH 3/7] tests: loading_cache_test: add a tests for a loading_cache::remove(key)/remove(iterator) Signed-off-by: Vlad Zolotarov --- tests/loading_cache_test.cc | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tests/loading_cache_test.cc b/tests/loading_cache_test.cc index 4e320eb8f5..51993f3046 100644 --- a/tests/loading_cache_test.cc +++ b/tests/loading_cache_test.cc @@ -229,6 +229,41 @@ SEASTAR_TEST_CASE(test_loading_cache_loading_same_key) { }); } +SEASTAR_THREAD_TEST_CASE(test_loading_cache_removing_key) { + using namespace std::chrono; + load_count = 0; + utils::loading_cache loading_cache(num_loaders, 100s, test_logger); + auto stop_cache_reload = seastar::defer([&loading_cache] { loading_cache.stop().get(); }); + + prepare().get(); + + loading_cache.get_ptr(0, loader).discard_result().get(); + BOOST_REQUIRE_EQUAL(load_count, 1); + BOOST_REQUIRE(loading_cache.find(0) != loading_cache.end()); + + loading_cache.remove(0); + BOOST_REQUIRE(loading_cache.find(0) == loading_cache.end()); +} + +SEASTAR_THREAD_TEST_CASE(test_loading_cache_removing_iterator) { + using namespace std::chrono; + load_count = 0; + utils::loading_cache loading_cache(num_loaders, 100s, test_logger); + auto stop_cache_reload = seastar::defer([&loading_cache] { loading_cache.stop().get(); }); + + prepare().get(); + + loading_cache.get_ptr(0, loader).discard_result().get(); + BOOST_REQUIRE_EQUAL(load_count, 1); + + auto it = loading_cache.find(0); + + BOOST_REQUIRE(it != loading_cache.end()); + + loading_cache.remove(it); + BOOST_REQUIRE(loading_cache.find(0) == loading_cache.end()); +} + SEASTAR_TEST_CASE(test_loading_cache_loading_different_keys) { return seastar::async([] { using namespace std::chrono; From 3114cef42caf6218a4198ed876e30971bafe0ec4 Mon Sep 17 00:00:00 2001 From: Vlad Zolotarov Date: Mon, 21 May 2018 16:22:29 -0400 Subject: [PATCH 4/7] loading_shared_values: introduce the templated find() overload This overload alows searching the elements by an arbitrary key as long as it is "hashable" to the same values as the default key and if there is a comparator for this new key. Signed-off-by: Vlad Zolotarov --- utils/loading_cache.hh | 19 +++++++++++++++++-- utils/loading_shared_values.hh | 34 +++++++++++++++++++++------------- 2 files changed, 38 insertions(+), 15 deletions(-) diff --git a/utils/loading_cache.hh b/utils/loading_cache.hh index cf718784ec..8ae74de30a 100644 --- a/utils/loading_cache.hh +++ b/utils/loading_cache.hh @@ -374,6 +374,11 @@ public: return _timer_reads_gate.close().finally([this] { _timer.cancel(); }); } + template + iterator find(const KeyType& key, KeyHasher key_hasher_func, KeyEqual key_equal_func) noexcept { + return boost::make_transform_iterator(set_find(key, std::move(key_hasher_func), std::move(key_equal_func)), _value_extractor_fn); + }; + iterator find(const Key& k) noexcept { return boost::make_transform_iterator(set_find(k), _value_extractor_fn); } @@ -425,8 +430,7 @@ public: } private: - set_iterator set_find(const Key& k) noexcept { - set_iterator it = _loading_values.find(k); + set_iterator ready_entry_iterator(set_iterator it) { set_iterator end_it = set_end(); if (it == end_it || !it->ready()) { @@ -435,6 +439,17 @@ private: return it; } + template + set_iterator set_find(const KeyType& key, KeyHasher key_hasher_func, KeyEqual key_equal_func) noexcept { + return ready_entry_iterator(_loading_values.find(key, std::move(key_hasher_func), std::move(key_equal_func))); + } + + // keep the default non-templated overloads to ease on the compiler for specifications + // that do not require the templated find(). + set_iterator set_find(const Key& key) noexcept { + return ready_entry_iterator(_loading_values.find(key)); + } + set_iterator set_end() noexcept { return _loading_values.end(); } diff --git a/utils/loading_shared_values.hh b/utils/loading_shared_values.hh index bfa6a276ff..55cc7fcae7 100644 --- a/utils/loading_shared_values.hh +++ b/utils/loading_shared_values.hh @@ -107,16 +107,6 @@ private: return bool(_val); } - struct key_eq { - bool operator()(const key_type& k, const entry& c) const { - return EqualPred()(k, c.key()); - } - - bool operator()(const entry& c, const key_type& k) const { - return EqualPred()(c.key(), k); - } - }; - entry(loading_shared_values& parent, key_type k) : _parent(parent), _key(std::move(k)) {} @@ -134,6 +124,17 @@ private: } }; + template + struct key_eq { + bool operator()(const KeyType& k, const entry& c) const { + return KeyEqual()(k, c.key()); + } + + bool operator()(const entry& c, const KeyType& k) const { + return KeyEqual()(c.key(), k); + } + }; + using set_type = bi::unordered_set, bi::compare_hash>; using bi_set_bucket_traits = typename set_type::bucket_traits; using set_iterator = typename set_type::iterator; @@ -216,7 +217,7 @@ public: future get_or_load(const key_type& key, Loader&& loader) noexcept { static_assert(std::is_same, typename futurize>::type>::value, "Bad Loader signature"); try { - auto i = _set.find(key, Hash(), typename entry::key_eq()); + auto i = _set.find(key, Hash(), key_eq()); lw_shared_ptr e; future<> f = make_ready_future<>(); if (i != _set.end()) { @@ -280,12 +281,19 @@ public: return boost::make_transform_iterator(_set.begin(), _value_extractor_fn); } - iterator find(const key_type& key) noexcept { - set_iterator it = _set.find(key, Hash(), typename entry::key_eq()); + template + iterator find(const KeyType& key, KeyHasher key_hasher_func, KeyEqual key_equal_func) noexcept { + set_iterator it = _set.find(key, std::move(key_hasher_func), key_eq()); if (it == _set.end() || !it->ready()) { return end(); } return boost::make_transform_iterator(it, _value_extractor_fn); + }; + + // keep the default non-templated overloads to ease on the compiler for specifications + // that do not require the templated find(). + iterator find(const key_type& key) noexcept { + return find(key, Hash(), EqualPred()); } private: From a138c59991516806ba99cb7714383e56c8deb0e4 Mon Sep 17 00:00:00 2001 From: Vlad Zolotarov Date: Fri, 11 May 2018 18:58:21 -0400 Subject: [PATCH 5/7] cql3: introduce the authorized_prepared_statements_cache class Add a cache that would store the checked weak pointer to already authorized prepared statements and which key is a tuple of an authenticated_user and key of the prepared_statements_cache. The entries will be held as long as the corresponding prepared statement is valid (cached) and will be discarded with the period equal to the refresh period of the permissions cache. Entries are also going to be discarded after 60 minutes if not used. The purpose of this new cache is to save the lookup in the permissions cache for already authenticated resource (whatever is needed to be authenticated for the particular prepared statement). This is meant to improve the cache coherency as well (since we are going to look in a single cache instead of two). Signed-off-by: Vlad Zolotarov --- cql3/authorized_prepared_statements_cache.hh | 189 +++++++++++++++++++ cql3/query_processor.cc | 88 ++++++--- cql3/query_processor.hh | 30 ++- tests/cql_test_env.cc | 2 +- thrift/handler.cc | 14 +- transport/server.cc | 13 +- 6 files changed, 308 insertions(+), 28 deletions(-) create mode 100644 cql3/authorized_prepared_statements_cache.hh diff --git a/cql3/authorized_prepared_statements_cache.hh b/cql3/authorized_prepared_statements_cache.hh new file mode 100644 index 0000000000..4eff0b1c4a --- /dev/null +++ b/cql3/authorized_prepared_statements_cache.hh @@ -0,0 +1,189 @@ +/* + * Copyright (C) 2018 ScyllaDB + */ + +/* + * This file is part of Scylla. + * + * Scylla is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * Scylla is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with Scylla. If not, see . + */ + +#pragma once + +#include "cql3/prepared_statements_cache.hh" + +namespace cql3 { + +struct authorized_prepared_statements_cache_size { + size_t operator()(const statements::prepared_statement::checked_weak_ptr& val) { + // TODO: improve the size approximation - most of the entry is occupied by the key here. + return 100; + } +}; + +class authorized_prepared_statements_cache_key { +public: + using cache_key_type = std::pair; +private: + cache_key_type _key; + +public: + authorized_prepared_statements_cache_key(auth::authenticated_user user, cql3::prepared_cache_key_type prepared_cache_key) + : _key(std::move(user), std::move(prepared_cache_key.key())) {} + + cache_key_type& key() { return _key; } + + const cache_key_type& key() const { return _key; } + + bool operator==(const authorized_prepared_statements_cache_key& other) const { + return _key == other._key; + } + + bool operator!=(const authorized_prepared_statements_cache_key& other) const { + return !(*this == other); + } + + static size_t hash(const auth::authenticated_user& user, const cql3::prepared_cache_key_type::cache_key_type& prep_cache_key) { + return utils::hash_combine(std::hash()(user), utils::tuple_hash()(prep_cache_key)); + } +}; + +/// \class authorized_prepared_statements_cache +/// \brief A cache of previously authorized statements. +/// +/// Entries are inserted every time a new statement is authorized. +/// Entries are evicted in any of the following cases: +/// - When the corresponding prepared statement is not valid anymore. +/// - Periodically, with the same period as the permission cache is refreshed. +/// - If the corresponding entry hasn't been used for \ref entry_expiry. +class authorized_prepared_statements_cache { +public: + struct stats { + uint64_t authorized_prepared_statements_cache_evictions = 0; + }; + + static stats& shard_stats() { + static thread_local stats _stats; + return _stats; + } + + struct authorized_prepared_statements_cache_stats_updater { + static void inc_hits() noexcept {} + static void inc_misses() noexcept {} + static void inc_blocks() noexcept {} + static void inc_evictions() noexcept { + ++shard_stats().authorized_prepared_statements_cache_evictions; + } + }; + +private: + using cache_key_type = authorized_prepared_statements_cache_key; + using checked_weak_ptr = typename statements::prepared_statement::checked_weak_ptr; + using cache_type = utils::loading_cache, + std::equal_to, + authorized_prepared_statements_cache_stats_updater>; + + static const std::chrono::minutes entry_expiry; + +public: + using key_type = cache_key_type; + using value_type = checked_weak_ptr; + using entry_is_too_big = typename cache_type::entry_is_too_big; + using iterator = typename cache_type::iterator; + +private: + cache_type _cache; + logging::logger& _logger; + +public: + // Choose the memory budget such that would allow us ~4K entries when a shard gets 1GB of RAM + authorized_prepared_statements_cache(std::chrono::milliseconds entry_refresh, logging::logger& logger) + : _cache(memory::stats().total_memory() / 2560, entry_expiry, entry_refresh, logger, [this] (const key_type& k) { + _cache.remove(k); + return make_ready_future(); + }) + , _logger(logger) + {} + + future<> insert(auth::authenticated_user user, cql3::prepared_cache_key_type prep_cache_key, value_type v) noexcept { + return _cache.get_ptr(key_type(std::move(user), std::move(prep_cache_key)), [v = std::move(v)] (const cache_key_type&) mutable { + return make_ready_future(std::move(v)); + }).discard_result(); + } + + iterator find(const auth::authenticated_user& user, const cql3::prepared_cache_key_type& prep_cache_key) { + struct key_view { + const auth::authenticated_user& user_ref; + const cql3::prepared_cache_key_type& prep_cache_key_ref; + }; + + struct hasher { + size_t operator()(const key_view& kv) { + return cql3::authorized_prepared_statements_cache_key::hash(kv.user_ref, kv.prep_cache_key_ref.key()); + } + }; + + struct equal { + bool operator()(const key_type& k1, const key_view& k2) { + return k1.key().first == k2.user_ref && k1.key().second == k2.prep_cache_key_ref.key(); + } + + bool operator()(const key_view& k2, const key_type& k1) { + return operator()(k1, k2); + } + }; + + return _cache.find(key_view{user, prep_cache_key}, hasher(), equal()); + } + + iterator end() { + return _cache.end(); + } + + void remove(const auth::authenticated_user& user, const cql3::prepared_cache_key_type& prep_cache_key) { + iterator it = find(user, prep_cache_key); + _cache.remove(it); + } + + size_t size() const { + return _cache.size(); + } + + size_t memory_footprint() const { + return _cache.memory_footprint(); + } + + future<> stop() { + return _cache.stop(); + } +}; + +} + +namespace std { +template <> +struct hash final { + size_t operator()(const cql3::authorized_prepared_statements_cache_key& k) const { + return cql3::authorized_prepared_statements_cache_key::hash(k.key().first, k.key().second); + } +}; + +inline std::ostream& operator<<(std::ostream& out, const cql3::authorized_prepared_statements_cache_key& k) { + return out << "{ " << k.key().first << ", " << k.key().second << " }"; +} +} \ No newline at end of file diff --git a/cql3/query_processor.cc b/cql3/query_processor.cc index cf240a44d6..0b11fef9cc 100644 --- a/cql3/query_processor.cc +++ b/cql3/query_processor.cc @@ -58,12 +58,14 @@ using namespace cql_transport::messages; logging::logger log("query_processor"); logging::logger prep_cache_log("prepared_statements_cache"); +logging::logger authorized_prepared_statements_cache_log("authorized_prepared_statements_cache"); distributed _the_query_processor; const sstring query_processor::CQL_VERSION = "3.3.1"; const std::chrono::minutes prepared_statements_cache::entry_expiry = std::chrono::minutes(60); +const std::chrono::minutes authorized_prepared_statements_cache::entry_expiry = std::chrono::minutes(60); class query_processor::internal_state { service::query_state _qs; @@ -96,7 +98,8 @@ query_processor::query_processor(service::storage_proxy& proxy, distributed query_processor::stop() { service::get_local_migration_manager().unregister_listener(_migration_subscriber.get()); - return _prepared_cache.stop(); + return _authorized_prepared_cache.stop().finally([this] { return _prepared_cache.stop(); }); } future<::shared_ptr> @@ -230,33 +249,60 @@ query_processor::process(const sstring_view& query_string, service::query_state& metrics.regularStatementsExecuted.inc(); #endif tracing::trace(query_state.get_trace_state(), "Processing a statement"); - return process_statement(std::move(cql_statement), query_state, options); + return process_statement_unprepared(std::move(cql_statement), query_state, options); } future<::shared_ptr> -query_processor::process_statement( +query_processor::process_statement_unprepared( ::shared_ptr statement, service::query_state& query_state, const query_options& options) { - return statement->check_access(query_state.get_client_state()).then([this, statement, &query_state, &options]() { - auto& client_state = query_state.get_client_state(); + return statement->check_access(query_state.get_client_state()).then([this, statement, &query_state, &options] () mutable { + return process_authorized_statement(std::move(statement), query_state, options); + }); +} - statement->validate(_proxy, client_state); +future<::shared_ptr> +query_processor::process_statement_prepared( + statements::prepared_statement::checked_weak_ptr prepared, + cql3::prepared_cache_key_type cache_key, + service::query_state& query_state, + const query_options& options, + bool needs_authorization) { - auto fut = make_ready_future<::shared_ptr>(); - if (client_state.is_internal()) { - fut = statement->execute_internal(_proxy, query_state, options); - } else { - fut = statement->execute(_proxy, query_state, options); - } - - return fut.then([statement] (auto msg) { - if (msg) { - return make_ready_future<::shared_ptr>(std::move(msg)); - } - return make_ready_future<::shared_ptr>( - ::make_shared()); + ::shared_ptr statement = prepared->statement; + future<> fut = make_ready_future<>(); + if (needs_authorization) { + fut = statement->check_access(query_state.get_client_state()).then([this, &query_state, prepared = std::move(prepared), cache_key = std::move(cache_key)] () mutable { + return _authorized_prepared_cache.insert(*query_state.get_client_state().user(), std::move(cache_key), std::move(prepared)).handle_exception([this] (auto eptr) { + log.error("failed to cache the entry", eptr); + }); }); + } + + return fut.then([this, statement = std::move(statement), &query_state, &options] () mutable { + return process_authorized_statement(std::move(statement), query_state, options); + }); +} + +future<::shared_ptr> +query_processor::process_authorized_statement(const ::shared_ptr statement, service::query_state& query_state, const query_options& options) { + auto& client_state = query_state.get_client_state(); + + statement->validate(_proxy, client_state); + + auto fut = make_ready_future<::shared_ptr>(); + if (client_state.is_internal()) { + fut = statement->execute_internal(_proxy, query_state, options); + } else { + fut = statement->execute(_proxy, query_state, options); + } + + return fut.then([statement] (auto msg) { + if (msg) { + return make_ready_future<::shared_ptr>(std::move(msg)); + } + return make_ready_future<::shared_ptr>(::make_shared()); }); } diff --git a/cql3/query_processor.hh b/cql3/query_processor.hh index 1c6a2b1b18..8ea6a539fe 100644 --- a/cql3/query_processor.hh +++ b/cql3/query_processor.hh @@ -49,6 +49,7 @@ #include #include "cql3/prepared_statements_cache.hh" +#include "cql3/authorized_prepared_statements_cache.hh" #include "cql3/query_options.hh" #include "cql3/statements/prepared_statement.hh" #include "cql3/statements/raw/parsed_statement.hh" @@ -117,6 +118,7 @@ private: std::unique_ptr _internal_state; prepared_statements_cache _prepared_cache; + authorized_prepared_statements_cache _authorized_prepared_cache; // A map for prepared statements used internally (which we don't want to mix with user statement, in particular we // don't bother with expiration on those. @@ -151,6 +153,21 @@ public: return _cql_stats; } + statements::prepared_statement::checked_weak_ptr get_prepared(const auth::authenticated_user* user_ptr, const prepared_cache_key_type& key) { + if (user_ptr) { + auto it = _authorized_prepared_cache.find(*user_ptr, key); + if (it != _authorized_prepared_cache.end()) { + try { + return it->get()->checked_weak_from_this(); + } catch (seastar::checked_ptr_is_null_exception&) { + // If the prepared statement got invalidated - remove the corresponding authorized_prepared_statements_cache entry as well. + _authorized_prepared_cache.remove(*user_ptr, key); + } + } + } + return statements::prepared_statement::checked_weak_ptr(); + } + statements::prepared_statement::checked_weak_ptr get_prepared(const prepared_cache_key_type& key) { auto it = _prepared_cache.find(key); if (it == _prepared_cache.end()) { @@ -160,11 +177,19 @@ public: } future<::shared_ptr> - process_statement( + process_statement_unprepared( ::shared_ptr statement, service::query_state& query_state, const query_options& options); + future<::shared_ptr> + process_statement_prepared( + statements::prepared_statement::checked_weak_ptr statement, + cql3::prepared_cache_key_type cache_key, + service::query_state& query_state, + const query_options& options, + bool needs_authorization); + future<::shared_ptr> process( const std::experimental::string_view& query_string, @@ -257,6 +282,9 @@ private: db::consistency_level = db::consistency_level::ONE, int32_t page_size = -1); + future<::shared_ptr> + process_authorized_statement(const ::shared_ptr statement, service::query_state& query_state, const query_options& options); + /*! * \brief created a state object for paging * diff --git a/tests/cql_test_env.cc b/tests/cql_test_env.cc index c3080c2dc3..59743618da 100644 --- a/tests/cql_test_env.cc +++ b/tests/cql_test_env.cc @@ -166,7 +166,7 @@ public: options->prepare(prepared->bound_names); auto qs = make_query_state(); - return local_qp().process_statement(stmt, *qs, *options) + return local_qp().process_statement_prepared(std::move(prepared), std::move(id), *qs, *options, true) .finally([options, qs, this] { _core_local.local().client_state.merge(qs->get_client_state()); }); diff --git a/thrift/handler.cc b/thrift/handler.cc index 815a449970..44f0c72811 100644 --- a/thrift/handler.cc +++ b/thrift/handler.cc @@ -999,9 +999,17 @@ public: void execute_prepared_cql3_query(tcxx::function cob, tcxx::function exn_cob, const int32_t itemId, const std::vector & values, const ConsistencyLevel::type consistency) { with_exn_cob(std::move(exn_cob), [&] { - auto prepared = _query_processor.local().get_prepared(cql3::prepared_cache_key_type(itemId)); + cql3::prepared_cache_key_type cache_key(itemId); + bool needs_authorization = false; + + auto prepared = _query_processor.local().get_prepared(_query_state.get_client_state().user().get(), cache_key); if (!prepared) { - throw make_exception("Prepared query with id %d not found", itemId); + needs_authorization = true; + + prepared = _query_processor.local().get_prepared(cache_key); + if (!prepared) { + throw make_exception("Prepared query with id %d not found", itemId); + } } auto stmt = prepared->statement; if (stmt->get_bound_terms() != values.size()) { @@ -1013,7 +1021,7 @@ public: }); auto opts = std::make_unique(cl_from_thrift(consistency), _timeout_config, stdx::nullopt, std::move(bytes_values), false, cql3::query_options::specific_options::DEFAULT, cql_serialization_format::latest()); - auto f = _query_processor.local().process_statement(stmt, _query_state, *opts); + auto f = _query_processor.local().process_statement_prepared(std::move(prepared), std::move(cache_key), _query_state, *opts, needs_authorization); return f.then([cob = std::move(cob), opts = std::move(opts)](auto&& ret) { cql3_result_visitor visitor; ret->accept(visitor); diff --git a/transport/server.cc b/transport/server.cc index e21bf63bf5..0cd3e6060b 100644 --- a/transport/server.cc +++ b/transport/server.cc @@ -910,7 +910,16 @@ future cql_server::connection::process_execute(uint16_t stream, b { cql3::prepared_cache_key_type cache_key(read_short_bytes(buf)); auto& id = cql3::prepared_cache_key_type::cql_id(cache_key); - auto prepared = _server._query_processor.local().get_prepared(cache_key); + bool needs_authorization = false; + + // First, try to lookup in the cache of already authorized statements. If the corresponding entry is not found there + // look for the prepared statement and then authorize it. + auto prepared = _server._query_processor.local().get_prepared(client_state.user().get(), cache_key); + if (!prepared) { + needs_authorization = true; + prepared = _server._query_processor.local().get_prepared(cache_key); + } + if (!prepared) { throw exceptions::prepared_query_not_found_exception(id); } @@ -945,7 +954,7 @@ future cql_server::connection::process_execute(uint16_t stream, b throw exceptions::invalid_request_exception("Invalid amount of bind variables"); } tracing::trace(query_state.get_trace_state(), "Processing a statement"); - return _server._query_processor.local().process_statement(stmt, query_state, options).then([this, stream, buf = std::move(buf), &query_state, skip_metadata] (auto msg) { + return _server._query_processor.local().process_statement_prepared(std::move(prepared), std::move(cache_key), query_state, options, needs_authorization).then([this, stream, buf = std::move(buf), &query_state, skip_metadata] (auto msg) { tracing::trace(query_state.get_trace_state(), "Done processing - preparing a result"); return this->make_result(stream, msg, query_state.get_trace_state(), skip_metadata); }).then([&query_state, q_state = std::move(q_state), this] (auto&& response) { From 9723988926f27135b6526c3e2c48213a9eb3f9b5 Mon Sep 17 00:00:00 2001 From: Vlad Zolotarov Date: Thu, 17 May 2018 19:23:33 -0400 Subject: [PATCH 6/7] cql3::statements::batch_statement: introduce a single_statement class This is a helper class needed to control the handling process of a single statement in the current batch. In particular it has the boolean defining if the authorization is needed for this statement. Signed-off-by: Vlad Zolotarov --- cql3/statements/batch_statement.cc | 42 +++++++++++++++++------------- cql3/statements/batch_statement.hh | 22 +++++++++++++--- tracing/trace_keyspace_helper.cc | 2 +- transport/server.cc | 2 +- 4 files changed, 44 insertions(+), 24 deletions(-) diff --git a/cql3/statements/batch_statement.cc b/cql3/statements/batch_statement.cc index 96d9db1f13..8916781b95 100644 --- a/cql3/statements/batch_statement.cc +++ b/cql3/statements/batch_statement.cc @@ -75,19 +75,19 @@ timeout_for_type(batch_statement::type t) { } batch_statement::batch_statement(int bound_terms, type type_, - std::vector> statements, + std::vector statements, std::unique_ptr attrs, cql_stats& stats) : cql_statement_no_metadata(timeout_for_type(type_)) , _bound_terms(bound_terms), _type(type_), _statements(std::move(statements)) , _attrs(std::move(attrs)) - , _has_conditions(boost::algorithm::any_of(_statements, std::mem_fn(&modification_statement::has_conditions))) + , _has_conditions(boost::algorithm::any_of(_statements, [] (auto&& s) { return s.statement->has_conditions(); })) , _stats(stats) { } batch_statement::batch_statement(type type_, - std::vector> statements, + std::vector statements, std::unique_ptr attrs, cql_stats& stats) : batch_statement(-1, type_, std::move(statements), std::move(attrs), stats) @@ -97,7 +97,7 @@ batch_statement::batch_statement(type type_, bool batch_statement::uses_function(const sstring& ks_name, const sstring& function_name) const { return _attrs->uses_function(ks_name, function_name) - || boost::algorithm::any_of(_statements, [&] (auto&& s) { return s->uses_function(ks_name, function_name); }); + || boost::algorithm::any_of(_statements, [&] (auto&& s) { return s.statement->uses_function(ks_name, function_name); }); } bool batch_statement::depends_on_keyspace(const sstring& ks_name) const @@ -118,7 +118,11 @@ uint32_t batch_statement::get_bound_terms() future<> batch_statement::check_access(const service::client_state& state) { return parallel_for_each(_statements.begin(), _statements.end(), [&state](auto&& s) { - return s->check_access(state); + if (s.needs_authorization) { + return s.statement->check_access(state); + } else { + return make_ready_future<>(); + } }); } @@ -138,12 +142,12 @@ void batch_statement::validate() } } - bool has_counters = boost::algorithm::any_of(_statements, std::mem_fn(&modification_statement::is_counter)); - bool has_non_counters = !boost::algorithm::all_of(_statements, std::mem_fn(&modification_statement::is_counter)); + bool has_counters = boost::algorithm::any_of(_statements, [] (auto&& s) { return s.statement->is_counter(); }); + bool has_non_counters = !boost::algorithm::all_of(_statements, [] (auto&& s) { return s.statement->is_counter(); }); if (timestamp_set && has_counters) { throw exceptions::invalid_request_exception("Cannot provide custom timestamp for a BATCH containing counters"); } - if (timestamp_set && boost::algorithm::any_of(_statements, std::mem_fn(&modification_statement::is_timestamp_set))) { + if (timestamp_set && boost::algorithm::any_of(_statements, [] (auto&& s) { return s.statement->is_timestamp_set(); })) { throw exceptions::invalid_request_exception("Timestamp must be set either on BATCH or individual statements"); } if (_type == type::COUNTER && has_non_counters) { @@ -159,30 +163,30 @@ void batch_statement::validate() if (_has_conditions && !_statements.empty() && (boost::distance(_statements - | boost::adaptors::transformed(std::mem_fn(&modification_statement::keyspace)) + | boost::adaptors::transformed([] (auto&& s) { return s.statement->keyspace(); }) | boost::adaptors::uniqued) != 1 || (boost::distance(_statements - | boost::adaptors::transformed(std::mem_fn(&modification_statement::column_family)) + | boost::adaptors::transformed([] (auto&& s) { return s.statement->column_family(); }) | boost::adaptors::uniqued) != 1))) { throw exceptions::invalid_request_exception("Batch with conditions cannot span multiple tables"); } std::experimental::optional raw_counter; for (auto& s : _statements) { - if (raw_counter && s->is_raw_counter_shard_write() != *raw_counter) { + if (raw_counter && s.statement->is_raw_counter_shard_write() != *raw_counter) { throw exceptions::invalid_request_exception("Cannot mix raw and regular counter statements in batch"); } - raw_counter = s->is_raw_counter_shard_write(); + raw_counter = s.statement->is_raw_counter_shard_write(); } } void batch_statement::validate(service::storage_proxy& proxy, const service::client_state& state) { for (auto&& s : _statements) { - s->validate(proxy, state); + s.statement->validate(proxy, state); } } -const std::vector>& batch_statement::get_statements() +const std::vector& batch_statement::get_statements() { return _statements; } @@ -196,7 +200,7 @@ future> batch_statement::get_mutations(service::storage_pr return do_for_each(boost::make_counting_iterator(0), boost::make_counting_iterator(_statements.size()), [this, &storage, &options, now, local, &result, trace_state] (size_t i) { - auto&& statement = _statements[i]; + auto&& statement = _statements[i].statement; statement->inc_cql_stats(); auto&& statement_options = options.for_statement(i); auto timestamp = _attrs->get_timestamp(now, statement_options); @@ -426,7 +430,9 @@ batch_statement::prepare(database& db, cql_stats& stats) { stdx::optional first_cf; bool have_multiple_cfs = false; - std::vector> statements; + std::vector statements; + statements.reserve(_parsed_statements.size()); + for (auto&& parsed : _parsed_statements) { if (!first_ks) { first_ks = parsed->keyspace(); @@ -434,7 +440,7 @@ batch_statement::prepare(database& db, cql_stats& stats) { } else { have_multiple_cfs = first_ks.value() != parsed->keyspace() || first_cf.value() != parsed->column_family(); } - statements.push_back(parsed->prepare(db, bound_names, stats)); + statements.emplace_back(parsed->prepare(db, bound_names, stats)); } auto&& prep_attrs = _attrs->prepare(db, "[batch]", "[batch]"); @@ -445,7 +451,7 @@ batch_statement::prepare(database& db, cql_stats& stats) { std::vector partition_key_bind_indices; if (!have_multiple_cfs && batch_statement_.get_statements().size() > 0) { - partition_key_bind_indices = bound_names->get_partition_key_bind_indexes(batch_statement_.get_statements()[0]->s); + partition_key_bind_indices = bound_names->get_partition_key_bind_indexes(batch_statement_.get_statements()[0].statement->s); } return std::make_unique(make_shared(std::move(batch_statement_)), bound_names->get_specifications(), diff --git a/cql3/statements/batch_statement.hh b/cql3/statements/batch_statement.hh index cc0d4595b9..f31d24e79f 100644 --- a/cql3/statements/batch_statement.hh +++ b/cql3/statements/batch_statement.hh @@ -66,10 +66,24 @@ class batch_statement : public cql_statement_no_metadata { static logging::logger _logger; public: using type = raw::batch_statement::type; + + struct single_statement { + shared_ptr statement; + bool needs_authorization = true; + + public: + single_statement(shared_ptr s) + : statement(std::move(s)) + {} + single_statement(shared_ptr s, bool na) + : statement(std::move(s)) + , needs_authorization(na) + {} + }; private: int _bound_terms; type _type; - std::vector> _statements; + std::vector _statements; std::unique_ptr _attrs; bool _has_conditions; cql_stats& _stats; @@ -83,12 +97,12 @@ public: * @param attrs additional attributes for statement (CL, timestamp, timeToLive) */ batch_statement(int bound_terms, type type_, - std::vector> statements, + std::vector statements, std::unique_ptr attrs, cql_stats& stats); batch_statement(type type_, - std::vector> statements, + std::vector statements, std::unique_ptr attrs, cql_stats& stats); @@ -109,7 +123,7 @@ public: // or in QueryProcessor.processBatch() - for native protocol batches. virtual void validate(service::storage_proxy& proxy, const service::client_state& state) override; - const std::vector>& get_statements(); + const std::vector& get_statements(); private: future> get_mutations(service::storage_proxy& storage, const query_options& options, bool local, api::timestamp_type now, tracing::trace_state_ptr trace_state); diff --git a/tracing/trace_keyspace_helper.cc b/tracing/trace_keyspace_helper.cc index 63c1d42de9..1aa5236a5f 100644 --- a/tracing/trace_keyspace_helper.cc +++ b/tracing/trace_keyspace_helper.cc @@ -357,7 +357,7 @@ future<> trace_keyspace_helper::apply_events_mutation(lw_shared_ptrsession_id, events_records.size(), records->parent_id, records->my_span_id); - std::vector> modifications(events_records.size(), _events.insert_stmt()); + std::vector modifications(events_records.size(), cql3::statements::batch_statement::single_statement(_events.insert_stmt(), false)); std::vector> values; auto& qp = cql3::get_local_query_processor(); diff --git a/transport/server.cc b/transport/server.cc index 0cd3e6060b..b19e057542 100644 --- a/transport/server.cc +++ b/transport/server.cc @@ -973,7 +973,7 @@ cql_server::connection::process_batch(uint16_t stream, bytes_view buf, service:: const auto type = read_byte(buf); const unsigned n = read_short(buf); - std::vector> modifications; + std::vector modifications; std::vector> values; modifications.reserve(n); From 82f7d1d006b42ce63958486e6498a401ade23607 Mon Sep 17 00:00:00 2001 From: Vlad Zolotarov Date: Thu, 17 May 2018 20:27:14 -0400 Subject: [PATCH 7/7] cql3: use authorized_prepared_statements_cache in the BATCH processing Like with the EXECUTE command avoid authorizing the same prepared statement twice - this time in the context of processing the BATCH command. Signed-off-by: Vlad Zolotarov --- cql3/prepared_statements_cache.hh | 15 +++++++++++++++ cql3/query_processor.cc | 17 ++++++++++++----- cql3/query_processor.hh | 6 +++++- transport/server.cc | 19 +++++++++++++++---- 4 files changed, 47 insertions(+), 10 deletions(-) diff --git a/cql3/prepared_statements_cache.hh b/cql3/prepared_statements_cache.hh index e6ae7e44e4..338186aff0 100644 --- a/cql3/prepared_statements_cache.hh +++ b/cql3/prepared_statements_cache.hh @@ -68,6 +68,14 @@ public: static thrift_prepared_id_type thrift_id(const prepared_cache_key_type& key) { return key.key().second; } + + bool operator==(const prepared_cache_key_type& other) const { + return _key == other._key; + } + + bool operator!=(const prepared_cache_key_type& other) const { + return !(*this == other); + } }; class prepared_statements_cache { @@ -172,4 +180,11 @@ inline std::ostream& operator<<(std::ostream& os, const cql3::prepared_cache_key os << p.key(); return os; } + +template<> +struct hash final { + size_t operator()(const cql3::prepared_cache_key_type& k) const { + return utils::tuple_hash()(k.key()); + } +}; } diff --git a/cql3/query_processor.cc b/cql3/query_processor.cc index 0b11fef9cc..a6c657b7ac 100644 --- a/cql3/query_processor.cc +++ b/cql3/query_processor.cc @@ -604,11 +604,18 @@ future<::shared_ptr> query_processor::process_batch( ::shared_ptr batch, service::query_state& query_state, - query_options& options) { - return batch->check_access(query_state.get_client_state()).then([this, &query_state, &options, batch] { - batch->validate(); - batch->validate(_proxy, query_state.get_client_state()); - return batch->execute(_proxy, query_state, options); + query_options& options, + std::unordered_map pending_authorization_entries) { + return batch->check_access(query_state.get_client_state()).then([this, &query_state, &options, batch, pending_authorization_entries = std::move(pending_authorization_entries)] () mutable { + return parallel_for_each(pending_authorization_entries, [this, &query_state] (auto& e) { + return _authorized_prepared_cache.insert(*query_state.get_client_state().user(), e.first, std::move(e.second)).handle_exception([this] (auto eptr) { + log.error("failed to cache the entry", eptr); + }); + }).then([this, &query_state, &options, batch] { + batch->validate(); + batch->validate(_proxy, query_state.get_client_state()); + return batch->execute(_proxy, query_state, options); + }); }); } diff --git a/cql3/query_processor.hh b/cql3/query_processor.hh index 8ea6a539fe..2f671e0c96 100644 --- a/cql3/query_processor.hh +++ b/cql3/query_processor.hh @@ -267,7 +267,11 @@ public: future<> stop(); future<::shared_ptr> - process_batch(::shared_ptr, service::query_state& query_state, query_options& options); + process_batch( + ::shared_ptr, + service::query_state& query_state, + query_options& options, + std::unordered_map pending_authorization_entries); std::unique_ptr get_statement( const std::experimental::string_view& query, diff --git a/transport/server.cc b/transport/server.cc index b19e057542..c4233d0644 100644 --- a/transport/server.cc +++ b/transport/server.cc @@ -975,6 +975,7 @@ cql_server::connection::process_batch(uint16_t stream, bytes_view buf, service:: std::vector modifications; std::vector> values; + std::unordered_map pending_authorization_entries; modifications.reserve(n); values.reserve(n); @@ -986,6 +987,7 @@ cql_server::connection::process_batch(uint16_t stream, bytes_view buf, service:: std::unique_ptr stmt_ptr; cql3::statements::prepared_statement::checked_weak_ptr ps; + bool needs_authorization(kind == 0); switch (kind) { case 0: { @@ -997,10 +999,19 @@ cql_server::connection::process_batch(uint16_t stream, bytes_view buf, service:: case 1: { cql3::prepared_cache_key_type cache_key(read_short_bytes(buf)); auto& id = cql3::prepared_cache_key_type::cql_id(cache_key); - ps = _server._query_processor.local().get_prepared(cache_key); + + // First, try to lookup in the cache of already authorized statements. If the corresponding entry is not found there + // look for the prepared statement and then authorize it. + ps = _server._query_processor.local().get_prepared(client_state.user().get(), cache_key); if (!ps) { - throw exceptions::prepared_query_not_found_exception(id); + ps = _server._query_processor.local().get_prepared(cache_key); + if (!ps) { + throw exceptions::prepared_query_not_found_exception(id); + } + // authorize a particular prepared statement only once + needs_authorization = pending_authorization_entries.emplace(std::move(cache_key), ps->checked_weak_from_this()).second; } + break; } default: @@ -1016,7 +1027,7 @@ cql_server::connection::process_batch(uint16_t stream, bytes_view buf, service:: ::shared_ptr modif_statement_ptr = static_pointer_cast(ps->statement); tracing::add_table_name(client_state.get_trace_state(), modif_statement_ptr->keyspace(), modif_statement_ptr->column_family()); - modifications.emplace_back(std::move(modif_statement_ptr)); + modifications.emplace_back(std::move(modif_statement_ptr), needs_authorization); std::vector tmp; read_value_view_list(buf, tmp); @@ -1040,7 +1051,7 @@ cql_server::connection::process_batch(uint16_t stream, bytes_view buf, service:: tracing::trace(client_state.get_trace_state(), "Creating a batch statement"); auto batch = ::make_shared(cql3::statements::batch_statement::type(type), std::move(modifications), cql3::attributes::none(), _server._query_processor.local().get_cql_stats()); - return _server._query_processor.local().process_batch(batch, query_state, options).then([this, stream, batch, &query_state] (auto msg) { + return _server._query_processor.local().process_batch(batch, query_state, options, std::move(pending_authorization_entries)).then([this, stream, batch, &query_state] (auto msg) { return this->make_result(stream, msg, query_state.get_trace_state()); }).then([&query_state, q_state = std::move(q_state), this] (auto&& response) { /* Keep q_state alive. */