diff --git a/auth/auth.cc b/auth/auth.cc index c3471d45d8..50900b0e0d 100644 --- a/auth/auth.cc +++ b/auth/auth.cc @@ -114,7 +114,7 @@ struct hash { class auth::auth::permissions_cache { public: - typedef utils::loading_cache, permission_set, utils::tuple_hash> cache_type; + typedef utils::loading_cache, permission_set, utils::loading_cache_reload_enabled::yes, utils::simple_entry_size, utils::tuple_hash> cache_type; typedef typename cache_type::key_type key_type; permissions_cache() diff --git a/configure.py b/configure.py index 671a7fe961..9c672ee6d2 100755 --- a/configure.py +++ b/configure.py @@ -238,6 +238,7 @@ scylla_tests = [ 'tests/view_schema_test', 'tests/counter_test', 'tests/cell_locker_test', + 'tests/loading_cache_test', ] apps = [ diff --git a/cql3/prepared_statements_cache.hh b/cql3/prepared_statements_cache.hh new file mode 100644 index 0000000000..c345fe520e --- /dev/null +++ b/cql3/prepared_statements_cache.hh @@ -0,0 +1,171 @@ +/* + * Copyright (C) 2017 ScyllaDB + * + * Modified by ScyllaDB + */ + +/* + * This file is part of Scylla. + * + * Scylla is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * Scylla is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with Scylla. If not, see . + */ + +#pragma once + +#include "utils/loading_cache.hh" +#include "cql3/statements/prepared_statement.hh" + +namespace cql3 { + +using prepared_cache_entry = std::unique_ptr; + +struct prepared_cache_entry_size { + size_t operator()(const prepared_cache_entry& val) { + // TODO: improve the size approximation + return 10000; + } +}; + +typedef bytes cql_prepared_id_type; +typedef int32_t thrift_prepared_id_type; + +/// \brief The key of the prepared statements cache +/// +/// We are going to store the CQL and Thrift prepared statements in the same cache therefore we need generate the key +/// that is going to be unique in both cases. Thrift use int32_t as a prepared statement ID, CQL - MD5 digest. +/// +/// We are going to use an std::pair as a key. For CQL statements we will use {CQL_PREP_ID, std::numeric_limits::max()} as a key +/// and for Thrift - {CQL_PREP_ID_TYPE(0), THRIFT_PREP_ID}. This way CQL and Thrift keys' values will never collide. +class prepared_cache_key_type { +public: + using cache_key_type = std::pair; + +private: + cache_key_type _key; + +public: + prepared_cache_key_type() = default; + explicit prepared_cache_key_type(cql_prepared_id_type cql_id) : _key(std::move(cql_id), std::numeric_limits::max()) {} + explicit prepared_cache_key_type(thrift_prepared_id_type thrift_id) : _key(cql_prepared_id_type(), thrift_id) {} + + cache_key_type& key() { return _key; } + const cache_key_type& key() const { return _key; } + + static const cql_prepared_id_type& cql_id(const prepared_cache_key_type& key) { + return key.key().first; + } + static thrift_prepared_id_type thrift_id(const prepared_cache_key_type& key) { + return key.key().second; + } +}; + +class prepared_statements_cache { +public: + struct stats { + uint64_t prepared_cache_evictions = 0; + }; + + static stats& shard_stats() { + static thread_local stats _stats; + return _stats; + } + + struct prepared_cache_stats_updater { + static void inc_hits() noexcept {} + static void inc_misses() noexcept {} + static void inc_blocks() noexcept {} + static void inc_evictions() noexcept { + ++shard_stats().prepared_cache_evictions; + } + }; + +private: + using cache_key_type = typename prepared_cache_key_type::cache_key_type; + using cache_type = utils::loading_cache, prepared_cache_stats_updater>; + using cache_value_ptr = typename cache_type::value_ptr; + using cache_iterator = typename cache_type::iterator; + using checked_weak_ptr = typename statements::prepared_statement::checked_weak_ptr; + struct value_extractor_fn { + checked_weak_ptr operator()(prepared_cache_entry& e) const { + return e->checked_weak_from_this(); + } + }; + + static const std::chrono::minutes entry_expiry; + +public: + using key_type = prepared_cache_key_type; + using value_type = checked_weak_ptr; + using statement_is_too_big = typename cache_type::entry_is_too_big; + /// \note both iterator::reference and iterator::value_type are checked_weak_ptr + using iterator = boost::transform_iterator; + +private: + cache_type _cache; + value_extractor_fn _value_extractor_fn; + +public: + prepared_statements_cache(logging::logger& logger) + : _cache(memory::stats().total_memory() / 256, entry_expiry, logger) + {} + + template + future get(const key_type& key, LoadFunc&& load) { + return _cache.get_ptr(key.key(), [load = std::forward(load)] (const cache_key_type&) { return load(); }).then([] (cache_value_ptr v_ptr) { + return make_ready_future((*v_ptr)->checked_weak_from_this()); + }); + } + + iterator find(const key_type& key) { + return boost::make_transform_iterator(_cache.find(key.key()), _value_extractor_fn); + } + + iterator end() { + return boost::make_transform_iterator(_cache.end(), _value_extractor_fn); + } + + iterator begin() { + return boost::make_transform_iterator(_cache.begin(), _value_extractor_fn); + } + + template + void remove_if(Pred&& pred) { + static_assert(std::is_same)>>::value, "Bad Pred signature"); + + _cache.remove_if([&pred] (const prepared_cache_entry& e) { + return pred(e->statement); + }); + } + + size_t size() const { + return _cache.size(); + } + + size_t memory_footprint() const { + return _cache.memory_footprint(); + } +}; +} + +namespace std { // for prepared_statements_cache log printouts +inline std::ostream& operator<<(std::ostream& os, const typename cql3::prepared_cache_key_type::cache_key_type& p) { + os << "{cql_id: " << p.first << ", thrift_id: " << p.second << "}"; + return os; +} + +inline std::ostream& operator<<(std::ostream& os, const cql3::prepared_cache_key_type& p) { + os << p.key(); + return os; +} +} diff --git a/cql3/query_processor.cc b/cql3/query_processor.cc index e15f5cfa40..25bc5a665f 100644 --- a/cql3/query_processor.cc +++ b/cql3/query_processor.cc @@ -57,11 +57,14 @@ using namespace statements; using namespace cql_transport::messages; logging::logger log("query_processor"); +logging::logger prep_cache_log("prepared_statements_cache"); distributed _the_query_processor; const sstring query_processor::CQL_VERSION = "3.3.1"; +const std::chrono::minutes prepared_statements_cache::entry_expiry = std::chrono::minutes(60); + class query_processor::internal_state { service::query_state _qs; public: @@ -95,6 +98,7 @@ query_processor::query_processor(distributed& proxy, , _proxy(proxy) , _db(db) , _internal_state(new internal_state()) + , _prepared_cache(prep_cache_log) { namespace sm = seastar::metrics; @@ -130,6 +134,15 @@ query_processor::query_processor(distributed& proxy, sm::make_derive("batches_unlogged_from_logged", _cql_stats.batches_unlogged_from_logged, sm::description("Counts a total number of LOGGED batches that were executed as UNLOGGED batches.")), + + sm::make_derive("prepared_cache_evictions", [] { return prepared_statements_cache::shard_stats().prepared_cache_evictions; }, + sm::description("Counts a number of prepared statements cache entries evictions.")), + + sm::make_gauge("prepared_cache_size", [this] { return _prepared_cache.size(); }, + sm::description("A number of entries in the prepared statements cache.")), + + sm::make_gauge("prepared_cache_memory_footprint", [this] { return _prepared_cache.memory_footprint(); }, + sm::description("Size (in bytes) of the prepared statements cache.")), }); service::get_local_migration_manager().register_listener(_migration_subscriber.get()); @@ -197,31 +210,21 @@ query_processor::process_statement(::shared_ptr statement, } future<::shared_ptr> -query_processor::prepare(const std::experimental::string_view& query_string, service::query_state& query_state) +query_processor::prepare(sstring query_string, service::query_state& query_state) { auto& client_state = query_state.get_client_state(); - return prepare(query_string, client_state, client_state.is_thrift()); + return prepare(std::move(query_string), client_state, client_state.is_thrift()); } future<::shared_ptr> -query_processor::prepare(const std::experimental::string_view& query_string, - const service::client_state& client_state, - bool for_thrift) +query_processor::prepare(sstring query_string, const service::client_state& client_state, bool for_thrift) { - auto existing = get_stored_prepared_statement(query_string, client_state.get_raw_keyspace(), for_thrift); - if (existing) { - return make_ready_future<::shared_ptr>(existing); + using namespace cql_transport::messages; + if (for_thrift) { + return prepare_one(std::move(query_string), client_state, compute_thrift_id, prepared_cache_key_type::thrift_id); + } else { + return prepare_one(std::move(query_string), client_state, compute_id, prepared_cache_key_type::cql_id); } - - return futurize<::shared_ptr>::apply([this, &query_string, &client_state, for_thrift] { - auto prepared = get_statement(query_string, client_state); - auto bound_terms = prepared->statement->get_bound_terms(); - if (bound_terms > std::numeric_limits::max()) { - throw exceptions::invalid_request_exception(sprint("Too many markers(?). %d markers exceed the allowed maximum of %d", bound_terms, std::numeric_limits::max())); - } - assert(bound_terms == prepared->bound_names.size()); - return store_prepared_statement(query_string, client_state.get_raw_keyspace(), std::move(prepared), for_thrift); - }); } ::shared_ptr @@ -229,50 +232,11 @@ query_processor::get_stored_prepared_statement(const std::experimental::string_v const sstring& keyspace, bool for_thrift) { + using namespace cql_transport::messages; if (for_thrift) { - auto statement_id = compute_thrift_id(query_string, keyspace); - auto it = _thrift_prepared_statements.find(statement_id); - if (it == _thrift_prepared_statements.end()) { - return ::shared_ptr(); - } - return ::make_shared(statement_id, it->second->checked_weak_from_this()); + return get_stored_prepared_statement_one(query_string, keyspace, compute_thrift_id, prepared_cache_key_type::thrift_id); } else { - auto statement_id = compute_id(query_string, keyspace); - auto it = _prepared_statements.find(statement_id); - if (it == _prepared_statements.end()) { - return ::shared_ptr(); - } - return ::make_shared(statement_id, it->second->checked_weak_from_this()); - } -} - -future<::shared_ptr> -query_processor::store_prepared_statement(const std::experimental::string_view& query_string, - const sstring& keyspace, - std::unique_ptr prepared, - bool for_thrift) -{ -#if 0 - // Concatenate the current keyspace so we don't mix prepared statements between keyspace (#5352). - // (if the keyspace is null, queryString has to have a fully-qualified keyspace so it's fine. - long statementSize = measure(prepared.statement); - // don't execute the statement if it's bigger than the allowed threshold - if (statementSize > MAX_CACHE_PREPARED_MEMORY) - throw new InvalidRequestException(String.format("Prepared statement of size %d bytes is larger than allowed maximum of %d bytes.", - statementSize, - MAX_CACHE_PREPARED_MEMORY)); -#endif - prepared->raw_cql_statement = query_string.data(); - if (for_thrift) { - auto statement_id = compute_thrift_id(query_string, keyspace); - auto msg = ::make_shared(statement_id, prepared->checked_weak_from_this()); - _thrift_prepared_statements.emplace(statement_id, std::move(prepared)); - return make_ready_future<::shared_ptr>(std::move(msg)); - } else { - auto statement_id = compute_id(query_string, keyspace); - auto msg = ::make_shared(statement_id, prepared->checked_weak_from_this()); - _prepared_statements.emplace(statement_id, std::move(prepared)); - return make_ready_future<::shared_ptr>(std::move(msg)); + return get_stored_prepared_statement_one(query_string, keyspace, compute_id, prepared_cache_key_type::cql_id); } } @@ -289,19 +253,19 @@ static sstring hash_target(const std::experimental::string_view& query_string, c return keyspace + query_string.to_string(); } -bytes query_processor::compute_id(const std::experimental::string_view& query_string, const sstring& keyspace) +prepared_cache_key_type query_processor::compute_id(const std::experimental::string_view& query_string, const sstring& keyspace) { - return md5_calculate(hash_target(query_string, keyspace)); + return prepared_cache_key_type(md5_calculate(hash_target(query_string, keyspace))); } -int32_t query_processor::compute_thrift_id(const std::experimental::string_view& query_string, const sstring& keyspace) +prepared_cache_key_type query_processor::compute_thrift_id(const std::experimental::string_view& query_string, const sstring& keyspace) { auto target = hash_target(query_string, keyspace); uint32_t h = 0; for (auto&& c : hash_target(query_string, keyspace)) { h = 31*h + c; } - return static_cast(h); + return prepared_cache_key_type(static_cast(h)); } std::unique_ptr @@ -527,7 +491,7 @@ void query_processor::migration_subscriber::on_drop_view(const sstring& ks_name, void query_processor::migration_subscriber::remove_invalid_prepared_statements(sstring ks_name, std::experimental::optional cf_name) { - _qp->invalidate_prepared_statements([&] (::shared_ptr stmt) { + _qp->_prepared_cache.remove_if([&] (::shared_ptr stmt) { return this->should_invalidate(ks_name, cf_name, stmt); }); } diff --git a/cql3/query_processor.hh b/cql3/query_processor.hh index 0e4ed6056c..aad9e32e33 100644 --- a/cql3/query_processor.hh +++ b/cql3/query_processor.hh @@ -57,6 +57,7 @@ #include "statements/prepared_statement.hh" #include "transport/messages/result_message.hh" #include "untyped_result_set.hh" +#include "prepared_statements_cache.hh" namespace cql3 { @@ -64,9 +65,32 @@ namespace statements { class batch_statement; } +class prepared_statement_is_too_big : public std::exception { +public: + static constexpr int max_query_prefix = 100; + +private: + sstring _msg; + +public: + prepared_statement_is_too_big(const sstring& query_string) + : _msg(seastar::format("Prepared statement is too big: {}", query_string.substr(0, max_query_prefix))) + { + // mark that we clipped the query string + if (query_string.size() > max_query_prefix) { + _msg += "..."; + } + } + + virtual const char* what() const noexcept override { + return _msg.c_str(); + } +}; + class query_processor { public: class migration_subscriber; + private: std::unique_ptr _migration_subscriber; distributed& _proxy; @@ -127,9 +151,7 @@ private: } }; #endif - - std::unordered_map> _prepared_statements; - std::unordered_map> _thrift_prepared_statements; + prepared_statements_cache _prepared_cache; std::unordered_map> _internal_statements; #if 0 @@ -221,21 +243,14 @@ private: } #endif public: - statements::prepared_statement::checked_weak_ptr get_prepared(const bytes& id) { - auto it = _prepared_statements.find(id); - if (it == _prepared_statements.end()) { + statements::prepared_statement::checked_weak_ptr get_prepared(const prepared_cache_key_type& key) { + auto it = _prepared_cache.find(key); + if (it == _prepared_cache.end()) { return statements::prepared_statement::checked_weak_ptr(); } - return it->second->checked_weak_from_this(); + return *it; } - statements::prepared_statement::checked_weak_ptr get_prepared_for_thrift(int32_t id) { - auto it = _thrift_prepared_statements.find(id); - if (it == _thrift_prepared_statements.end()) { - return statements::prepared_statement::checked_weak_ptr(); - } - return it->second->checked_weak_from_this(); - } #if 0 public static void validateKey(ByteBuffer key) throws InvalidRequestException { @@ -435,42 +450,61 @@ public: #endif future<::shared_ptr> - prepare(const std::experimental::string_view& query_string, service::query_state& query_state); + prepare(sstring query_string, service::query_state& query_state); future<::shared_ptr> - prepare(const std::experimental::string_view& query_string, const service::client_state& client_state, bool for_thrift); + prepare(sstring query_string, const service::client_state& client_state, bool for_thrift); - static bytes compute_id(const std::experimental::string_view& query_string, const sstring& keyspace); - static int32_t compute_thrift_id(const std::experimental::string_view& query_string, const sstring& keyspace); + static prepared_cache_key_type compute_id(const std::experimental::string_view& query_string, const sstring& keyspace); + static prepared_cache_key_type compute_thrift_id(const std::experimental::string_view& query_string, const sstring& keyspace); private: + /// + /// \tparam ResultMsgType type of the returned result message (CQL or Thrift) + /// \tparam PreparedKeyGenerator a function that generates the prepared statement cache key for given query and keyspace + /// \tparam IdGetter a function that returns the corresponding prepared statement ID (CQL or Thrift) for a given prepared statement cache key + /// \param query_string + /// \param client_state + /// \param id_gen prepared ID generator, called before the first deferring + /// \param id_getter prepared ID getter, passed to deferred context by reference. The caller must ensure its liveness. + /// \return + template + future<::shared_ptr> + prepare_one(sstring query_string, const service::client_state& client_state, PreparedKeyGenerator&& id_gen, IdGetter&& id_getter) { + return do_with(id_gen(query_string, client_state.get_raw_keyspace()), std::move(query_string), [this, &client_state, &id_getter] (const prepared_cache_key_type& key, const sstring& query_string) { + return _prepared_cache.get(key, [this, &query_string, &client_state] { + auto prepared = get_statement(query_string, client_state); + auto bound_terms = prepared->statement->get_bound_terms(); + if (bound_terms > std::numeric_limits::max()) { + throw exceptions::invalid_request_exception(sprint("Too many markers(?). %d markers exceed the allowed maximum of %d", bound_terms, std::numeric_limits::max())); + } + assert(bound_terms == prepared->bound_names.size()); + prepared->raw_cql_statement = query_string; + return make_ready_future>(std::move(prepared)); + }).then([&key, &id_getter] (auto prep_ptr) { + return make_ready_future<::shared_ptr>(::make_shared(id_getter(key), std::move(prep_ptr))); + }).handle_exception_type([&query_string] (typename prepared_statements_cache::statement_is_too_big&) { + return make_exception_future<::shared_ptr>(prepared_statement_is_too_big(query_string)); + }); + }); + }; + + template + ::shared_ptr + get_stored_prepared_statement_one(const std::experimental::string_view& query_string, const sstring& keyspace, KeyGenerator&& key_gen, IdGetter&& id_getter) + { + auto cache_key = key_gen(query_string, keyspace); + auto it = _prepared_cache.find(cache_key); + if (it == _prepared_cache.end()) { + return ::shared_ptr(); + } + + return ::make_shared(id_getter(cache_key), *it); + } + ::shared_ptr get_stored_prepared_statement(const std::experimental::string_view& query_string, const sstring& keyspace, bool for_thrift); - future<::shared_ptr> - store_prepared_statement(const std::experimental::string_view& query_string, const sstring& keyspace, std::unique_ptr prepared, bool for_thrift); - - // Erases the statements for which filter returns true. - template - void invalidate_prepared_statements(Pred filter) { - static_assert(std::is_same)>>::value, - "bad Pred signature"); - for (auto it = _prepared_statements.begin(); it != _prepared_statements.end(); ) { - if (filter(it->second->statement)) { - it = _prepared_statements.erase(it); - } else { - ++it; - } - } - for (auto it = _thrift_prepared_statements.begin(); it != _thrift_prepared_statements.end(); ) { - if (filter(it->second->statement)) { - it = _thrift_prepared_statements.erase(it); - } else { - ++it; - } - } - } - #if 0 public ResultMessage processPrepared(CQLStatement statement, QueryState queryState, QueryOptions options) throws RequestExecutionException, RequestValidationException diff --git a/sstables/shared_index_lists.hh b/sstables/shared_index_lists.hh index 4d96b7c2be..7d96fabfd7 100644 --- a/sstables/shared_index_lists.hh +++ b/sstables/shared_index_lists.hh @@ -21,10 +21,9 @@ #pragma once -#include #include -#include #include +#include "utils/loading_shared_values.hh" namespace sstables { @@ -36,50 +35,26 @@ using index_list = std::vector; class shared_index_lists { public: using key_type = uint64_t; - struct stats { + static thread_local struct stats { uint64_t hits = 0; // Number of times entry was found ready uint64_t misses = 0; // Number of times entry was not found uint64_t blocks = 0; // Number of times entry was not ready (>= misses) - }; -private: - class entry : public enable_lw_shared_from_this { - public: - key_type key; - index_list list; - shared_promise<> loaded; - shared_index_lists& parent; + } _shard_stats; - entry(shared_index_lists& parent, key_type key) - : key(key), parent(parent) - { } - ~entry() { - parent._lists.erase(key); - } - bool operator==(const entry& e) const { return key == e.key; } - bool operator!=(const entry& e) const { return key != e.key; } + struct stats_updater { + static void inc_hits() noexcept { ++_shard_stats.hits; } + static void inc_misses() noexcept { ++_shard_stats.misses; } + static void inc_blocks() noexcept { ++_shard_stats.blocks; } + static void inc_evictions() noexcept {} }; - std::unordered_map _lists; - static thread_local stats _shard_stats; -public: + + using loading_shared_lists_type = utils::loading_shared_values, std::equal_to, stats_updater>; // Pointer to index_list - class list_ptr { - lw_shared_ptr _e; - public: - using element_type = index_list; - list_ptr() = default; - explicit list_ptr(lw_shared_ptr e) : _e(std::move(e)) {} - explicit operator bool() const { return static_cast(_e); } - index_list& operator*() { return _e->list; } - const index_list& operator*() const { return _e->list; } - index_list* operator->() { return &_e->list; } - const index_list* operator->() const { return &_e->list; } + using list_ptr = loading_shared_lists_type::entry_ptr; +private: - index_list release() { - auto res = _e.owned() ? index_list(std::move(_e->list)) : index_list(_e->list); - _e = {}; - return std::move(res); - } - }; + loading_shared_lists_type _lists; +public: shared_index_lists() = default; shared_index_lists(shared_index_lists&&) = delete; @@ -93,41 +68,8 @@ public: // // The loader object does not survive deferring, so the caller must deal with its liveness. template - future get_or_load(key_type key, Loader&& loader) { - auto i = _lists.find(key); - lw_shared_ptr e; - auto f = [&] { - if (i != _lists.end()) { - e = i->second->shared_from_this(); - return e->loaded.get_shared_future(); - } else { - ++_shard_stats.misses; - e = make_lw_shared(*this, key); - auto f = e->loaded.get_shared_future(); - auto res = _lists.emplace(key, e.get()); - assert(res.second); - futurize_apply(loader, key).then_wrapped([e](future&& f) mutable { - if (f.failed()) { - e->loaded.set_exception(f.get_exception()); - } else { - e->list = f.get0(); - e->loaded.set_value(); - } - }); - return f; - } - }(); - if (!f.available()) { - ++_shard_stats.blocks; - return f.then([e]() mutable { - return list_ptr(std::move(e)); - }); - } else if (f.failed()) { - return make_exception_future(std::move(f).get_exception()); - } else { - ++_shard_stats.hits; - return make_ready_future(list_ptr(std::move(e))); - } + future get_or_load(const key_type& key, Loader&& loader) { + return _lists.get_or_load(key, std::forward(loader)); } static const stats& shard_stats() { return _shard_stats; } diff --git a/test.py b/test.py index db9887aa60..8f8925ced6 100755 --- a/test.py +++ b/test.py @@ -82,6 +82,7 @@ boost_tests = [ 'counter_test', 'cell_locker_test', 'clustering_ranges_walker_test', + 'loading_cache_test', ] other_tests = [ diff --git a/tests/cql_test_env.cc b/tests/cql_test_env.cc index 25a4b1a903..cbb786427c 100644 --- a/tests/cql_test_env.cc +++ b/tests/cql_test_env.cc @@ -120,7 +120,7 @@ public: }); } - virtual future prepare(sstring query) override { + virtual future prepare(sstring query) override { return qp().invoke_on_all([query, this] (auto& local_qp) { auto qs = this->make_query_state(); return local_qp.prepare(query, *qs).finally([qs] {}).discard_result(); @@ -130,7 +130,7 @@ public: } virtual future<::shared_ptr> execute_prepared( - bytes id, + cql3::prepared_cache_key_type id, std::vector values) override { auto prepared = local_qp().get_prepared(id); diff --git a/tests/cql_test_env.hh b/tests/cql_test_env.hh index ffd66d4b66..956a22bfd2 100644 --- a/tests/cql_test_env.hh +++ b/tests/cql_test_env.hh @@ -32,6 +32,7 @@ #include "transport/messages/result_message_base.hh" #include "cql3/query_options_fwd.hh" #include "cql3/values.hh" +#include "cql3/prepared_statements_cache.hh" #include "bytes.hh" #include "schema.hh" @@ -43,7 +44,7 @@ namespace cql3 { class not_prepared_exception : public std::runtime_error { public: - not_prepared_exception(const bytes& id) : std::runtime_error(sprint("Not prepared: %s", id)) {} + not_prepared_exception(const cql3::prepared_cache_key_type& id) : std::runtime_error(sprint("Not prepared: %s", id)) {} }; namespace db { @@ -59,10 +60,10 @@ public: virtual future<::shared_ptr> execute_cql( const sstring& text, std::unique_ptr qo) = 0; - virtual future prepare(sstring query) = 0; + virtual future prepare(sstring query) = 0; virtual future<::shared_ptr> execute_prepared( - bytes id, std::vector values) = 0; + cql3::prepared_cache_key_type id, std::vector values) = 0; virtual future<> create_table(std::function schema_maker) = 0; diff --git a/tests/loading_cache_test.cc b/tests/loading_cache_test.cc new file mode 100644 index 0000000000..0e919db250 --- /dev/null +++ b/tests/loading_cache_test.cc @@ -0,0 +1,321 @@ +/* + * Copyright (C) 2017 ScyllaDB + */ + +/* + * This file is part of Scylla. + * + * Scylla is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * Scylla is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with Scylla. If not, see . + */ + +#include +#include "utils/loading_shared_values.hh" +#include "utils/loading_cache.hh" +#include +#include +#include +#include +#include + + +#include "seastarx.hh" + +#include "tests/test-utils.hh" +#include "tmpdir.hh" +#include "log.hh" + +#include +#include +#include + +/// Get a random integer in the [0, max) range. +/// \param upper bound of the random value range +/// \return The uniformly distributed random integer from the [0, \ref max) range. +static int rand_int(int max) { + std::random_device rd; // only used once to initialise (seed) engine + std::mt19937 rng(rd()); // random-number engine used (Mersenne-Twister in this case) + std::uniform_int_distribution uni(0, max - 1); // guaranteed unbiased + return uni(rng); +} + + +#include "disk-error-handler.hh" + +thread_local disk_error_signal_type general_disk_error; +thread_local disk_error_signal_type commit_error; + +static const sstring test_file_name = "loading_cache_test.txt"; +static const sstring test_string = "1"; +static bool file_prepared = false; +static constexpr int num_loaders = 1000; + +static logging::logger test_logger("loading_cache_test"); + +static thread_local int load_count; +static const tmpdir& get_tmpdir() { + static thread_local tmpdir tmp; + return tmp; +} + +static future<> prepare() { + if (file_prepared) { + return make_ready_future<>(); + } + + return open_file_dma((boost::filesystem::path(get_tmpdir().path) / test_file_name.c_str()).c_str(), open_flags::create | open_flags::wo).then([] (file f) { + return do_with(std::move(f), [] (file& f) { + return f.dma_write(0, test_string.c_str(), test_string.size() + 1).then([] (size_t s) { + BOOST_REQUIRE_EQUAL(s, test_string.size() + 1); + file_prepared = true; + }); + }); + }); +} + +static future loader(const int& k) { + return open_file_dma((boost::filesystem::path(get_tmpdir().path) / test_file_name.c_str()).c_str(), open_flags::ro).then([] (file f) -> future { + return do_with(std::move(f), [] (file& f) -> future { + return f.dma_read_exactly(0, test_string.size() + 1).then([] (auto buf) { + sstring str(buf.get()); + BOOST_REQUIRE_EQUAL(str, test_string); + ++load_count; + return make_ready_future(std::move(str)); + }); + }); + }); +} + +SEASTAR_TEST_CASE(test_loading_shared_values_parallel_loading_same_key) { + return seastar::async([] { + std::vector ivec(num_loaders); + load_count = 0; + utils::loading_shared_values shared_values; + std::list::entry_ptr> anchors_list; + + prepare().get(); + + std::fill(ivec.begin(), ivec.end(), 0); + + parallel_for_each(ivec, [&] (int& k) { + return shared_values.get_or_load(k, loader).then([&] (auto entry_ptr) { + anchors_list.emplace_back(std::move(entry_ptr)); + }); + }).get(); + + // "loader" must be called exactly once + BOOST_REQUIRE_EQUAL(load_count, 1); + BOOST_REQUIRE_EQUAL(shared_values.size(), 1); + anchors_list.clear(); + }); +} + +SEASTAR_TEST_CASE(test_loading_shared_values_parallel_loading_different_keys) { + return seastar::async([] { + std::vector ivec(num_loaders); + load_count = 0; + utils::loading_shared_values shared_values; + std::list::entry_ptr> anchors_list; + + prepare().get(); + + std::iota(ivec.begin(), ivec.end(), 0); + + parallel_for_each(ivec, [&] (int& k) { + return shared_values.get_or_load(k, loader).then([&] (auto entry_ptr) { + anchors_list.emplace_back(std::move(entry_ptr)); + }); + }).get(); + + // "loader" must be called once for each key + BOOST_REQUIRE_EQUAL(load_count, num_loaders); + BOOST_REQUIRE_EQUAL(shared_values.size(), num_loaders); + anchors_list.clear(); + }); +} + +SEASTAR_TEST_CASE(test_loading_shared_values_rehash) { + return seastar::async([] { + std::vector ivec(num_loaders); + load_count = 0; + utils::loading_shared_values shared_values; + std::list::entry_ptr> anchors_list; + + prepare().get(); + + std::iota(ivec.begin(), ivec.end(), 0); + + // verify that load factor is always in the (0.25, 0.75) range + for (int k = 0; k < num_loaders; ++k) { + shared_values.get_or_load(k, loader).then([&] (auto entry_ptr) { + anchors_list.emplace_back(std::move(entry_ptr)); + }).get(); + BOOST_REQUIRE_LE(shared_values.size(), 3 * shared_values.buckets_count() / 4); + } + + BOOST_REQUIRE_GE(shared_values.size(), shared_values.buckets_count() / 4); + + // minimum buckets count (by default) is 16, so don't check for less than 4 elements + for (int k = 0; k < num_loaders - 4; ++k) { + anchors_list.pop_back(); + shared_values.rehash(); + BOOST_REQUIRE_GE(shared_values.size(), shared_values.buckets_count() / 4); + } + + anchors_list.clear(); + }); +} + +SEASTAR_TEST_CASE(test_loading_shared_values_parallel_loading_explicit_eviction) { + return seastar::async([] { + std::vector ivec(num_loaders); + load_count = 0; + utils::loading_shared_values shared_values; + std::vector::entry_ptr> anchors_vec(num_loaders); + + prepare().get(); + + std::iota(ivec.begin(), ivec.end(), 0); + + parallel_for_each(ivec, [&] (int& k) { + return shared_values.get_or_load(k, loader).then([&] (auto entry_ptr) { + anchors_vec[k] = std::move(entry_ptr); + }); + }).get(); + + int rand_key = rand_int(num_loaders); + BOOST_REQUIRE(shared_values.find(rand_key) != shared_values.end()); + anchors_vec[rand_key] = nullptr; + BOOST_REQUIRE_MESSAGE(shared_values.find(rand_key) == shared_values.end(), format("explicit removal for key {} failed", rand_key)); + anchors_vec.clear(); + }); +} + +SEASTAR_TEST_CASE(test_loading_cache_loading_same_key) { + return seastar::async([] { + using namespace std::chrono; + std::vector ivec(num_loaders); + load_count = 0; + utils::loading_cache loading_cache(num_loaders, 1s, test_logger); + + prepare().get(); + + std::fill(ivec.begin(), ivec.end(), 0); + + parallel_for_each(ivec, [&] (int& k) { + return loading_cache.get_ptr(k, loader).discard_result(); + }).get(); + + // "loader" must be called exactly once + BOOST_REQUIRE_EQUAL(load_count, 1); + BOOST_REQUIRE_EQUAL(loading_cache.size(), 1); + loading_cache.stop().get(); + }); +} + +SEASTAR_TEST_CASE(test_loading_cache_loading_different_keys) { + return seastar::async([] { + using namespace std::chrono; + std::vector ivec(num_loaders); + load_count = 0; + utils::loading_cache loading_cache(num_loaders, 1s, test_logger); + + prepare().get(); + + std::iota(ivec.begin(), ivec.end(), 0); + + parallel_for_each(ivec, [&] (int& k) { + return loading_cache.get_ptr(k, loader).discard_result(); + }).get(); + + BOOST_REQUIRE_EQUAL(load_count, num_loaders); + BOOST_REQUIRE_EQUAL(loading_cache.size(), num_loaders); + loading_cache.stop().get(); + }); +} + +SEASTAR_TEST_CASE(test_loading_cache_loading_expiry_eviction) { + return seastar::async([] { + using namespace std::chrono; + utils::loading_cache loading_cache(num_loaders, 20ms, test_logger); + + prepare().get(); + + loading_cache.get_ptr(0, loader).discard_result().get(); + + BOOST_REQUIRE(loading_cache.find(0) != loading_cache.end()); + + // timers get delayed sometimes (especially in a debug mode) + constexpr int max_retry = 10; + int i = 0; + do_until( + [&] { return i++ > max_retry || loading_cache.find(0) == loading_cache.end(); }, + [] { return sleep(40ms); } + ).get(); + BOOST_REQUIRE(loading_cache.find(0) == loading_cache.end()); + loading_cache.stop().get(); + }); +} + +SEASTAR_TEST_CASE(test_loading_cache_loading_reloading) { + return seastar::async([] { + using namespace std::chrono; + load_count = 0; + utils::loading_cache loading_cache(num_loaders, 100ms, 20ms, test_logger, loader); + prepare().get(); + loading_cache.get_ptr(0, loader).discard_result().get(); + sleep(60ms).get(); + BOOST_REQUIRE_MESSAGE(load_count >= 2, format("load_count is {}", load_count)); + loading_cache.stop().get(); + }); +} + +SEASTAR_TEST_CASE(test_loading_cache_max_size_eviction) { + return seastar::async([] { + using namespace std::chrono; + load_count = 0; + utils::loading_cache loading_cache(1, 1s, test_logger); + + prepare().get(); + + for (int i = 0; i < num_loaders; ++i) { + loading_cache.get_ptr(i % 2, loader).discard_result().get(); + } + + BOOST_REQUIRE_EQUAL(load_count, num_loaders); + BOOST_REQUIRE_EQUAL(loading_cache.size(), 1); + loading_cache.stop().get(); + }); +} + +SEASTAR_TEST_CASE(test_loading_cache_reload_during_eviction) { + return seastar::async([] { + using namespace std::chrono; + load_count = 0; + utils::loading_cache loading_cache(1, 100ms, 10ms, test_logger, loader); + + prepare().get(); + + auto curr_time = lowres_clock::now(); + int i = 0; + + // this will cause reloading when values are being actively evicted due to the limited cache size + do_until( + [&] { return lowres_clock::now() - curr_time > 1s; }, + [&] { return loading_cache.get_ptr(i++ % 2).discard_result(); } + ).get(); + + BOOST_REQUIRE_EQUAL(loading_cache.size(), 1); + loading_cache.stop().get(); + }); +} \ No newline at end of file diff --git a/tests/schema_change_test.cc b/tests/schema_change_test.cc index cd1902b141..6d2d129e1c 100644 --- a/tests/schema_change_test.cc +++ b/tests/schema_change_test.cc @@ -408,7 +408,7 @@ SEASTAR_TEST_CASE(test_prepared_statement_is_invalidated_by_schema_change) { logging::logger_registry().set_logger_level("query_processor", logging::log_level::debug); e.execute_cql("create keyspace tests with replication = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };").get(); e.execute_cql("create table tests.table1 (pk int primary key, c1 int, c2 int);").get(); - bytes id = e.prepare("select * from tests.table1;").get0(); + auto id = e.prepare("select * from tests.table1;").get0(); e.execute_cql("alter table tests.table1 add s1 int;").get(); diff --git a/thrift/handler.cc b/thrift/handler.cc index c65b224b4c..1d2f34971c 100644 --- a/thrift/handler.cc +++ b/thrift/handler.cc @@ -1002,7 +1002,7 @@ public: void execute_prepared_cql3_query(tcxx::function cob, tcxx::function exn_cob, const int32_t itemId, const std::vector & values, const ConsistencyLevel::type consistency) { with_exn_cob(std::move(exn_cob), [&] { - auto prepared = _query_processor.local().get_prepared_for_thrift(itemId); + auto prepared = _query_processor.local().get_prepared(cql3::prepared_cache_key_type(itemId)); if (!prepared) { throw make_exception("Prepared query with id %d not found", itemId); } diff --git a/transport/server.cc b/transport/server.cc index 680e596066..17e18aae2e 100644 --- a/transport/server.cc +++ b/transport/server.cc @@ -826,15 +826,14 @@ future cql_server::connection::process_prepare(uint16_t stream, b return parallel_for_each(cpus.begin(), cpus.end(), [this, query, cpu_id, &cs] (unsigned int c) mutable { if (c != cpu_id) { return smp::submit_to(c, [this, query, &cs] () mutable { - _server._query_processor.local().prepare(query, cs, false); - // FIXME: error handling + return _server._query_processor.local().prepare(std::move(query), cs, false).discard_result(); }); } else { return make_ready_future<>(); } - }).then([this, query, stream, &cs] { + }).then([this, query, stream, &cs] () mutable { tracing::trace(cs.get_trace_state(), "Done preparing on remote shards"); - return _server._query_processor.local().prepare(query, cs, false).then([this, stream, &cs] (auto msg) { + return _server._query_processor.local().prepare(std::move(query), cs, false).then([this, stream, &cs] (auto msg) { tracing::trace(cs.get_trace_state(), "Done preparing on a local shard - preparing a result. ID is [{}]", seastar::value_of([&msg] { return messages::result_message::prepared::cql::get_id(msg); })); @@ -848,8 +847,9 @@ future cql_server::connection::process_prepare(uint16_t stream, b future cql_server::connection::process_execute(uint16_t stream, bytes_view buf, service::client_state client_state) { - auto id = read_short_bytes(buf); - auto prepared = _server._query_processor.local().get_prepared(id); + cql3::prepared_cache_key_type cache_key(read_short_bytes(buf)); + auto& id = cql3::prepared_cache_key_type::cql_id(cache_key); + auto prepared = _server._query_processor.local().get_prepared(cache_key); if (!prepared) { throw exceptions::prepared_query_not_found_exception(id); } @@ -925,8 +925,9 @@ cql_server::connection::process_batch(uint16_t stream, bytes_view buf, service:: break; } case 1: { - auto id = read_short_bytes(buf); - ps = _server._query_processor.local().get_prepared(id); + cql3::prepared_cache_key_type cache_key(read_short_bytes(buf)); + auto& id = cql3::prepared_cache_key_type::cql_id(cache_key); + ps = _server._query_processor.local().get_prepared(cache_key); if (!ps) { throw exceptions::prepared_query_not_found_exception(id); } diff --git a/utils/loading_cache.hh b/utils/loading_cache.hh index b6aa8cc4b7..c3a9e502af 100644 --- a/utils/loading_cache.hh +++ b/utils/loading_cache.hh @@ -29,77 +29,54 @@ #include #include -#include "utils/exceptions.hh" +#include "exceptions/exceptions.hh" +#include "utils/loading_shared_values.hh" +#include "log.hh" namespace bi = boost::intrusive; namespace utils { -// Simple variant of the "LoadingCache" used for permissions in origin. -typedef lowres_clock loading_cache_clock_type; -typedef bi::list_base_hook> auto_unlink_list_hook; +using loading_cache_clock_type = seastar::lowres_clock; +using auto_unlink_list_hook = bi::list_base_hook>; -template -class timestamped_val : public auto_unlink_list_hook, public bi::unordered_set_base_hook> { +template +class timestamped_val { public: - typedef bi::list> lru_list_type; - typedef Key key_type; - typedef Tp value_type; + using value_type = Tp; + using loading_values_type = typename utils::loading_shared_values; + class lru_entry; + class value_ptr; private: - std::experimental::optional _opt_value; + value_type _value; loading_cache_clock_type::time_point _loaded; loading_cache_clock_type::time_point _last_read; - lru_list_type& _lru_list; /// MRU item is at the front, LRU - at the back - Key _key; + lru_entry* _lru_entry_ptr = nullptr; /// MRU item is at the front, LRU - at the back + size_t _size = 0; public: - struct key_eq { - bool operator()(const Key& k, const timestamped_val& c) const { - return EqualPred()(k, c.key()); - } - - bool operator()(const timestamped_val& c, const Key& k) const { - return EqualPred()(c.key(), k); - } - }; - - timestamped_val(lru_list_type& lru_list, const Key& key) - : _loaded(loading_cache_clock_type::now()) + timestamped_val(value_type val) + : _value(std::move(val)) + , _loaded(loading_cache_clock_type::now()) , _last_read(_loaded) - , _lru_list(lru_list) - , _key(key) {} - - timestamped_val(lru_list_type& lru_list, Key&& key) - : _loaded(loading_cache_clock_type::now()) - , _last_read(_loaded) - , _lru_list(lru_list) - , _key(std::move(key)) {} - - timestamped_val(const timestamped_val&) = default; + , _size(EntrySize()(_value)) + {} timestamped_val(timestamped_val&&) = default; - // Make sure copy/move-assignments don't go through the template below - timestamped_val& operator=(const timestamped_val&) = default; - timestamped_val& operator=(timestamped_val&) = default; - timestamped_val& operator=(timestamped_val&&) = default; + timestamped_val& operator=(value_type new_val) { + assert(_lru_entry_ptr); - template - timestamped_val& operator=(U&& new_val) { - _opt_value = std::forward(new_val); + _value = std::move(new_val); _loaded = loading_cache_clock_type::now(); + _lru_entry_ptr->cache_size() -= _size; + _size = EntrySize()(_value); + _lru_entry_ptr->cache_size() += _size; return *this; } - const Tp& value() { - _last_read = loading_cache_clock_type::now(); - touch(); - return _opt_value.value(); - } - - explicit operator bool() const noexcept { - return bool(_opt_value); - } + value_type& value() noexcept { return _value; } + const value_type& value() const noexcept { return _value; } loading_cache_clock_type::time_point last_read() const noexcept { return _last_read; @@ -109,163 +86,353 @@ public: return _loaded; } - const Key& key() const { - return _key; + size_t size() const { + return _size; } - friend bool operator==(const timestamped_val& a, const timestamped_val& b){ - return EqualPred()(a.key(), b.key()); - } - - friend std::size_t hash_value(const timestamped_val& v) { - return Hash()(v.key()); + bool ready() const noexcept { + return _lru_entry_ptr; } private: + void touch() noexcept { + assert(_lru_entry_ptr); + _last_read = loading_cache_clock_type::now(); + _lru_entry_ptr->touch(); + } + + void set_anchor_back_reference(lru_entry* lru_entry_ptr) noexcept { + _lru_entry_ptr = lru_entry_ptr; + } +}; + +template +struct simple_entry_size { + size_t operator()(const Tp& val) { + return 1; + } +}; + +template +class timestamped_val::value_ptr { +private: + using ts_value_type = timestamped_val; + using loading_values_type = typename ts_value_type::loading_values_type; + +public: + using timestamped_val_ptr = typename loading_values_type::entry_ptr; + using value_type = Tp; + +private: + timestamped_val_ptr _ts_val_ptr; + +public: + value_ptr(timestamped_val_ptr ts_val_ptr) : _ts_val_ptr(std::move(ts_val_ptr)) { _ts_val_ptr->touch(); } + explicit operator bool() const noexcept { return bool(_ts_val_ptr); } + value_type& operator*() const noexcept { return _ts_val_ptr->value(); } + value_type* operator->() const noexcept { return &_ts_val_ptr->value(); } +}; + +/// \brief This is and LRU list entry which is also an anchor for a loading_cache value. +template +class timestamped_val::lru_entry : public auto_unlink_list_hook { +private: + using ts_value_type = timestamped_val; + using loading_values_type = typename ts_value_type::loading_values_type; + +public: + using lru_list_type = bi::list>; + using timestamped_val_ptr = typename loading_values_type::entry_ptr; + +private: + timestamped_val_ptr _ts_val_ptr; + lru_list_type& _lru_list; + size_t& _cache_size; + +public: + lru_entry(timestamped_val_ptr ts_val, lru_list_type& lru_list, size_t& cache_size) + : _ts_val_ptr(std::move(ts_val)) + , _lru_list(lru_list) + , _cache_size(cache_size) + { + _ts_val_ptr->set_anchor_back_reference(this); + _cache_size += _ts_val_ptr->size(); + } + + ~lru_entry() { + _cache_size -= _ts_val_ptr->size(); + _ts_val_ptr->set_anchor_back_reference(nullptr); + } + + size_t& cache_size() noexcept { + return _cache_size; + } + /// Set this item as the most recently used item. /// The MRU item is going to be at the front of the _lru_list, the LRU item - at the back. void touch() noexcept { auto_unlink_list_hook::unlink(); _lru_list.push_front(*this); } -}; -class shared_mutex { -private: - lw_shared_ptr _mutex_ptr; - -public: - shared_mutex() : _mutex_ptr(make_lw_shared(1)) {} - semaphore& get() const noexcept { - return *_mutex_ptr; + const Key& key() const noexcept { + return loading_values_type::to_key(_ts_val_ptr); } + + timestamped_val& timestamped_value() noexcept { return *_ts_val_ptr; } + const timestamped_val& timestamped_value() const noexcept { return *_ts_val_ptr; } + timestamped_val_ptr timestamped_value_ptr() noexcept { return _ts_val_ptr; } }; +enum class loading_cache_reload_enabled { no, yes }; + +/// \brief Loading cache is a cache that loads the value into the cache using the given asynchronous callback. +/// +/// Each cached value if reloading is enabled (\tparam ReloadEnabled == loading_cache_reload_enabled::yes) is reloaded after +/// the "refresh" time period since it was loaded for the last time. +/// +/// The values are going to be evicted from the cache if they are not accessed during the "expiration" period or haven't +/// been reloaded even once during the same period. +/// +/// If "expiration" is set to zero - the caching is going to be disabled and get_XXX(...) is going to call the "loader" callback +/// every time in order to get the requested value. +/// +/// \note In order to avoid the eviction of cached entries due to "aging" of the contained value the user has to choose +/// the "expiration" to be at least ("refresh" + "max load latency"). This way the value is going to stay in the cache and is going to be +/// read in a non-blocking way as long as it's frequently accessed. Note however that since reloading is an asynchronous +/// procedure it may get delayed by other running task. Therefore choosing the "expiration" too close to the ("refresh" + "max load latency") +/// value one risks to have his/her cache values evicted when the system is heavily loaded. +/// +/// The cache is also limited in size and if adding the next value is going +/// to exceed the cache size limit the least recently used value(s) is(are) going to be evicted until the size of the cache +/// becomes such that adding the new value is not going to break the size limit. If the new entry's size is greater than +/// the cache size then the get_XXX(...) method is going to return a future with the loading_cache::entry_is_too_big exception. +/// +/// The size of the cache is defined as a sum of sizes of all cached entries. +/// The size of each entry is defined by the value returned by the \tparam EntrySize predicate applied on it. +/// +/// The get(key) or get_ptr(key) methods ensures that the "loader" callback is called only once for each cached entry regardless of how many +/// callers are calling for the get_XXX(key) for the same "key" at the same time. Only after the value is evicted from the cache +/// it's going to be "loaded" in the context of get_XXX(key). As long as the value is cached get_XXX(key) is going to return the +/// cached value immediately and reload it in the background every "refresh" time period as described above. +/// +/// \tparam Key type of the cache key +/// \tparam Tp type of the cached value +/// \tparam ReloadEnabled if loading_cache_reload_enabled::yes allow reloading the values otherwise don't reload +/// \tparam EntrySize predicate to calculate the entry size +/// \tparam Hash hash function +/// \tparam EqualPred equality predicate +/// \tparam LoadingSharedValuesStats statistics incrementing class (see utils::loading_shared_values) +/// \tparam Alloc elements allocator template, typename Hash = std::hash, typename EqualPred = std::equal_to, - typename Alloc = std::allocator>, - typename SharedMutexMapAlloc = std::allocator>> + typename LoadingSharedValuesStats = utils::do_nothing_loading_shared_values_stats, + typename Alloc = std::allocator::lru_entry>> class loading_cache { private: - typedef timestamped_val ts_value_type; - typedef bi::unordered_set, bi::compare_hash> set_type; - typedef std::unordered_map write_mutex_map_type; - typedef typename ts_value_type::lru_list_type lru_list_type; - typedef typename set_type::bucket_traits bi_set_bucket_traits; - - static constexpr int initial_num_buckets = 256; - static constexpr int max_num_buckets = 1024 * 1024; + using ts_value_type = timestamped_val; + using loading_values_type = typename ts_value_type::loading_values_type; + using timestamped_val_ptr = typename loading_values_type::entry_ptr; + using ts_value_lru_entry = typename ts_value_type::lru_entry; + using set_iterator = typename loading_values_type::iterator; + using lru_list_type = typename ts_value_lru_entry::lru_list_type; + struct value_extractor_fn { + Tp& operator()(ts_value_type& tv) const { + return tv.value(); + } + }; public: - typedef Tp value_type; - typedef Key key_type; - typedef typename set_type::iterator iterator; + using value_type = Tp; + using key_type = Key; + using value_ptr = typename ts_value_type::value_ptr; + class entry_is_too_big : public std::exception {}; + using iterator = boost::transform_iterator; + +private: + loading_cache(size_t max_size, std::chrono::milliseconds expiry, std::chrono::milliseconds refresh, logging::logger& logger) + : _max_size(max_size) + , _expiry(expiry) + , _refresh(refresh) + , _logger(logger) + , _timer([this] { on_timer(); }) + { + // Sanity check: if expiration period is given then non-zero refresh period and maximal size are required + if (caching_enabled() && (_refresh == std::chrono::milliseconds(0) || _max_size == 0)) { + throw exceptions::configuration_exception("loading_cache: caching is enabled but refresh period and/or max_size are zero"); + } + } + +public: template loading_cache(size_t max_size, std::chrono::milliseconds expiry, std::chrono::milliseconds refresh, logging::logger& logger, Func&& load) - : _buckets(initial_num_buckets) - , _set(bi_set_bucket_traits(_buckets.data(), _buckets.size())) - , _max_size(max_size) - , _expiry(expiry) - , _refresh(refresh) - , _logger(logger) - , _load(std::forward(load)) { + : loading_cache(max_size, expiry, refresh, logger) + { + static_assert(ReloadEnabled == loading_cache_reload_enabled::yes, "This constructor should only be invoked when ReloadEnabled == loading_cache_reload_enabled::yes"); + static_assert(std::is_same, std::result_of_t>::value, "Bad Func signature"); + + _load = std::forward(load); // If expiration period is zero - caching is disabled if (!caching_enabled()) { return; } - // Sanity check: if expiration period is given then non-zero refresh period and maximal size are required - if (_refresh == std::chrono::milliseconds(0) || _max_size == 0) { - throw exceptions::configuration_exception("loading_cache: caching is enabled but refresh period and/or max_size are zero"); + _timer_period = std::min(_expiry, _refresh); + _timer.arm(_timer_period); + } + + loading_cache(size_t max_size, std::chrono::milliseconds expiry, logging::logger& logger) + : loading_cache(max_size, expiry, loading_cache_clock_type::time_point::max().time_since_epoch(), logger) + { + static_assert(ReloadEnabled == loading_cache_reload_enabled::no, "This constructor should only be invoked when ReloadEnabled == loading_cache_reload_enabled::no"); + + // If expiration period is zero - caching is disabled + if (!caching_enabled()) { + return; } - _timer.set_callback([this] { on_timer(); }); - _timer.arm(_refresh); + _timer_period = _expiry; + _timer.arm(_timer_period); } ~loading_cache() { - _set.clear_and_dispose([] (ts_value_type* ptr) { loading_cache::destroy_ts_value(ptr); }); + _lru_list.erase_and_dispose(_lru_list.begin(), _lru_list.end(), [] (ts_value_lru_entry* ptr) { loading_cache::destroy_ts_value(ptr); }); + } + + template + future get_ptr(const Key& k, LoadFunc&& load) { + static_assert(std::is_same, std::result_of_t>::value, "Bad LoadFunc signature"); + // We shouldn't be here if caching is disabled + assert(caching_enabled()); + + return _loading_values.get_or_load(k, [this, load = std::forward(load)] (const Key& k) mutable { + return load(k).then([this] (value_type val) { + return ts_value_type(std::move(val)); + }); + }).then([this, k] (timestamped_val_ptr ts_val_ptr) { + // check again since it could have already been inserted and initialized + if (!ts_val_ptr->ready()) { + _logger.trace("{}: storing the value for the first time", k); + + if (ts_val_ptr->size() > _max_size) { + return make_exception_future(entry_is_too_big()); + } + + ts_value_lru_entry* new_lru_entry = Alloc().allocate(1); + new(new_lru_entry) ts_value_lru_entry(std::move(ts_val_ptr), _lru_list, _current_size); + + // This will "touch" the entry and add it to the LRU list - we must do this before the shrink() call. + value_ptr vp(new_lru_entry->timestamped_value_ptr()); + + // Remove the least recently used items if map is too big. + shrink(); + + return make_ready_future(std::move(vp)); + } + + return make_ready_future(std::move(ts_val_ptr)); + }); + } + + future get_ptr(const Key& k) { + static_assert(ReloadEnabled == loading_cache_reload_enabled::yes); + return get_ptr(k, _load); } future get(const Key& k) { + static_assert(ReloadEnabled == loading_cache_reload_enabled::yes); + // If caching is disabled - always load in the foreground if (!caching_enabled()) { - return _load(k); + return _load(k).then([] (Tp val) { + return make_ready_future(std::move(val)); + }); } - // If the key is not in the cache yet, then find_or_create() is going to - // create a new uninitialized value in the map. If the value is already - // in the cache (the fast path) simply return the value. Otherwise, take - // the mutex and try to load the value (the slow path). - iterator ts_value_it = find_or_create(k); - if (*ts_value_it) { - return make_ready_future(ts_value_it->value()); - } else { - return slow_load(k); - } + return get_ptr(k).then([] (value_ptr v_ptr) { + return make_ready_future(*v_ptr); + }); } future<> stop() { return _timer_reads_gate.close().finally([this] { _timer.cancel(); }); } + iterator find(const Key& k) noexcept { + return boost::make_transform_iterator(set_find(k), _value_extractor_fn); + } + + iterator end() { + return boost::make_transform_iterator(_loading_values.end(), _value_extractor_fn); + } + + iterator begin() { + return boost::make_transform_iterator(_loading_values.begin(), _value_extractor_fn); + } + + template + void remove_if(Pred&& pred) { + static_assert(std::is_same>::value, "Bad Pred signature"); + + _lru_list.remove_and_dispose_if([this, &pred] (const ts_value_lru_entry& v) { + return pred(v.timestamped_value().value()); + }, [this] (ts_value_lru_entry* p) { + loading_cache::destroy_ts_value(p); + }); + } + + size_t size() const { + return _loading_values.size(); + } + + /// \brief returns the memory size the currently cached entries occupy according to the EntrySize predicate. + size_t memory_footprint() const { + return _current_size; + } + private: + set_iterator set_find(const Key& k) noexcept { + set_iterator it = _loading_values.find(k); + set_iterator end_it = set_end(); + + if (it == end_it || !it->ready()) { + return end_it; + } + return it; + } + + set_iterator set_end() noexcept { + return _loading_values.end(); + } + + set_iterator set_begin() noexcept { + return _loading_values.begin(); + } + bool caching_enabled() const { return _expiry != std::chrono::milliseconds(0); } - /// Look for the entry with the given key. It it doesn't exist - create a new one and add it to the _set. - /// - /// \param k The key to look for - /// - /// \return An iterator to the value with the given key (always dirrerent from _set.end()) - template - iterator find_or_create(KeyType&& k) { - iterator i = _set.find(k, Hash(), typename ts_value_type::key_eq()); - if (i == _set.end()) { - ts_value_type* new_ts_val = Alloc().allocate(1); - new(new_ts_val) ts_value_type(_lru_list, std::forward(k)); - auto p = _set.insert(*new_ts_val); - i = p.first; - } - - return i; - } - - static void destroy_ts_value(ts_value_type* val) { - val->~ts_value_type(); + static void destroy_ts_value(ts_value_lru_entry* val) { + val->~ts_value_lru_entry(); Alloc().deallocate(val, 1); } - future slow_load(const Key& k) { - // If the key is not in the cache yet, then _write_mutex_map[k] is going - // to create a new value with the initialized mutex. The mutex is going - // to serialize the producers and only the first one is going to - // actually issue a load operation and initialize the value with the - // received result. The rest are going to see (and read) the initialized - // value when they enter the critical section. - shared_mutex sm = _write_mutex_map[k]; - return with_semaphore(sm.get(), 1, [this, k] { - iterator ts_value_it = find_or_create(k); - if (*ts_value_it) { - return make_ready_future(ts_value_it->value()); + future<> reload(ts_value_lru_entry& lru_entry) { + return _load(lru_entry.key()).then_wrapped([this, key = lru_entry.key()] (auto&& f) mutable { + // if the entry has been evicted by now - simply end here + set_iterator it = set_find(key); + if (it == set_end()) { + _logger.trace("{}: entry was dropped during the reload", key); + return make_ready_future<>(); } - _logger.trace("{}: storing the value for the first time", k); - return _load(k).then([this, k] (Tp t) { - // we have to "re-read" the _set here because the value may have been evicted by now - iterator ts_value_it = find_or_create(std::move(k)); - *ts_value_it = std::move(t); - return make_ready_future(ts_value_it->value()); - }); - }).finally([sm] {}); - } - future<> reload(ts_value_type& ts_val) { - return _load(ts_val.key()).then_wrapped([this, &ts_val] (auto&& f) { // The exceptions are related to the load operation itself. // We should ignore them for the background reads - if // they persist the value will age and will be reloaded in @@ -273,120 +440,97 @@ private: // will be propagated up to the user and will fail the // corresponding query. try { - ts_val = f.get0(); + *it = f.get0(); } catch (std::exception& e) { - _logger.debug("{}: reload failed: {}", ts_val.key(), e.what()); + _logger.debug("{}: reload failed: {}", key, e.what()); } catch (...) { - _logger.debug("{}: reload failed: unknown error", ts_val.key()); + _logger.debug("{}: reload failed: unknown error", key); } - }); - } - void erase(iterator it) { - _set.erase_and_dispose(it, [] (ts_value_type* ptr) { loading_cache::destroy_ts_value(ptr); }); - // no need to delete the item from _lru_list - it's auto-deleted + return make_ready_future<>(); + }); } void drop_expired() { auto now = loading_cache_clock_type::now(); - _lru_list.remove_and_dispose_if([now, this] (const ts_value_type& v) { + _lru_list.remove_and_dispose_if([now, this] (const ts_value_lru_entry& lru_entry) { using namespace std::chrono; // An entry should be discarded if it hasn't been reloaded for too long or nobody cares about it anymore + const ts_value_type& v = lru_entry.timestamped_value(); auto since_last_read = now - v.last_read(); auto since_loaded = now - v.loaded(); - if (_expiry < since_last_read || _expiry < since_loaded) { - _logger.trace("drop_expired(): {}: dropping the entry: _expiry {}, ms passed since: loaded {} last_read {}", v.key(), _expiry.count(), duration_cast(since_loaded).count(), duration_cast(since_last_read).count()); + if (_expiry < since_last_read || (ReloadEnabled == loading_cache_reload_enabled::yes && _expiry < since_loaded)) { + _logger.trace("drop_expired(): {}: dropping the entry: _expiry {}, ms passed since: loaded {} last_read {}", lru_entry.key(), _expiry.count(), duration_cast(since_loaded).count(), duration_cast(since_last_read).count()); return true; } return false; - }, [this] (ts_value_type* p) { - erase(_set.iterator_to(*p)); + }, [this] (ts_value_lru_entry* p) { + loading_cache::destroy_ts_value(p); }); } // Shrink the cache to the _max_size discarding the least recently used items void shrink() { - if (_set.size() > _max_size) { - auto num_items_to_erase = _set.size() - _max_size; - for (size_t i = 0; i < num_items_to_erase; ++i) { - using namespace std::chrono; - ts_value_type& ts_val = *_lru_list.rbegin(); - _logger.trace("shrink(): {}: dropping the entry: ms since last_read {}", ts_val.key(), duration_cast(loading_cache_clock_type::now() - ts_val.last_read()).count()); - erase(_set.iterator_to(ts_val)); - } + while (_current_size > _max_size) { + using namespace std::chrono; + ts_value_lru_entry& lru_entry = *_lru_list.rbegin(); + _logger.trace("shrink(): {}: dropping the entry: ms since last_read {}", lru_entry.key(), duration_cast(loading_cache_clock_type::now() - lru_entry.timestamped_value().last_read()).count()); + loading_cache::destroy_ts_value(&lru_entry); } } - void rehash() { - size_t new_buckets_count = 0; - - // Don't grow or shrink too fast even if there is a steep drop/growth in the number of elements in the set. - // Exponential growth/backoff should be good enough. - // - // Try to keep the load factor between 0.25 and 1.0. - if (_set.size() < _current_buckets_count / 4) { - new_buckets_count = _current_buckets_count / 4; - } else if (_set.size() > _current_buckets_count) { - new_buckets_count = _current_buckets_count * 2; + // Try to bring the load factors of the _loading_values into a known range. + void periodic_rehash() noexcept { + try { + _loading_values.rehash(); + } catch (...) { + // if rehashing fails - continue with the current buckets array } - - if (new_buckets_count < initial_num_buckets || new_buckets_count > max_num_buckets) { - return; - } - - std::vector new_buckets(new_buckets_count); - _set.rehash(bi_set_bucket_traits(new_buckets.data(), new_buckets.size())); - _logger.trace("rehash(): buckets count changed: {} -> {}", _current_buckets_count, new_buckets_count); - - _buckets.swap(new_buckets); - _current_buckets_count = new_buckets_count; } void on_timer() { _logger.trace("on_timer(): start"); - auto timer_start_tp = loading_cache_clock_type::now(); - - // Clear all cached mutexes - _write_mutex_map.clear(); - // Clean up items that were not touched for the whole _expiry period. drop_expired(); - // Remove the least recently used items if map is too big. - shrink(); - // check if rehashing is needed and do it if it is. - rehash(); + periodic_rehash(); + + if (ReloadEnabled == loading_cache_reload_enabled::no) { + _logger.trace("on_timer(): rearming"); + _timer.arm(loading_cache_clock_type::now() + _timer_period); + return; + } // Reload all those which vlaue needs to be reloaded. - with_gate(_timer_reads_gate, [this, timer_start_tp] { - return parallel_for_each(_set.begin(), _set.end(), [this, curr_time = timer_start_tp] (auto& ts_val) { - _logger.trace("on_timer(): {}: checking the value age", ts_val.key()); - if (ts_val && ts_val.loaded() + _refresh < curr_time) { - _logger.trace("on_timer(): {}: reloading the value", ts_val.key()); - return this->reload(ts_val); + with_gate(_timer_reads_gate, [this] { + return parallel_for_each(_lru_list.begin(), _lru_list.end(), [this] (ts_value_lru_entry& lru_entry) { + _logger.trace("on_timer(): {}: checking the value age", lru_entry.key()); + if (lru_entry.timestamped_value().loaded() + _refresh < loading_cache_clock_type::now()) { + _logger.trace("on_timer(): {}: reloading the value", lru_entry.key()); + return this->reload(lru_entry); } return now(); - }).finally([this, timer_start_tp] { + }).finally([this] { _logger.trace("on_timer(): rearming"); - _timer.arm(timer_start_tp + _refresh); + _timer.arm(loading_cache_clock_type::now() + _timer_period); }); }); } - std::vector _buckets; - size_t _current_buckets_count = initial_num_buckets; - set_type _set; - write_mutex_map_type _write_mutex_map; + loading_values_type _loading_values; lru_list_type _lru_list; - size_t _max_size; + size_t _current_size = 0; + size_t _max_size = 0; std::chrono::milliseconds _expiry; std::chrono::milliseconds _refresh; + loading_cache_clock_type::duration _timer_period; logging::logger& _logger; std::function(const Key&)> _load; timer _timer; seastar::gate _timer_reads_gate; + value_extractor_fn _value_extractor_fn; }; } diff --git a/utils/loading_shared_values.hh b/utils/loading_shared_values.hh index 2f10863da7..bfa6a276ff 100644 --- a/utils/loading_shared_values.hh +++ b/utils/loading_shared_values.hh @@ -137,7 +137,11 @@ private: using set_type = bi::unordered_set, bi::compare_hash>; using bi_set_bucket_traits = typename set_type::bucket_traits; using set_iterator = typename set_type::iterator; - using value_extractor_fn = std::function; + struct value_extractor_fn { + value_type& operator()(entry& e) const { + return e.value(); + } + }; enum class shrinking_is_allowed { no, yes }; public: @@ -186,7 +190,6 @@ public: loading_shared_values() : _buckets(InitialBucketsCount) , _set(bi_set_bucket_traits(_buckets.data(), _buckets.size())) - , _value_extractor_fn([] (entry& e) -> value_type& { return e.value(); }) { static_assert(noexcept(Stats::inc_evictions()), "Stats::inc_evictions must be non-throwing"); static_assert(noexcept(Stats::inc_hits()), "Stats::inc_hits must be non-throwing");