mirror of
https://github.com/scylladb/scylladb.git
synced 2026-06-07 07:23:15 +00:00
Merge 'Introduce authorized_prepared_statements_cache' from Vlad
"
This series introduces a cache of already authenticated prepared statements which
is meant to optimize the prepared statement lookup when authentication is enabled.
This cache allows to perform a single cache lookup per EXECUTE operation as opposed
to at least 2 lookups: one in the prepared statements cache and one in the authentication
cache.
Tests:
- cql_query_test {debug, release}.
- cassandra-stress with authentication enabled and with short eviction timeout.
- Manual (with printouts) checks:
- Tested the eviction due to eviction in the prepared_statements_cache:
- Artificially decreased the prepared_statements_cache size and ran c-s with different keyspaces.
- Verified that the corresponding authorized_prepared_statements_cache entry is evicted and re-populated.
- Tested the BATCH of prepared statements (with dtest infrastructure):
- Verified that for each prepared statement authorized_prepared_statements_cache is updated only once:
- The batch contained a few entries of the same prepared statement.
"
* 'authorized_prepared_statements_cache-v3' of https://github.com/vladzcloudius/scylla:
cql3: use authorized_prepared_statements_cache in the BATCH processing
cql3::statements::batch_statement: introduce a single_statement class
cql3: introduce the authorized_prepared_statements_cache class
loading_shared_values: introduce the templated find() overload
tests: loading_cache_test: add a tests for a loading_cache::remove(key)/remove(iterator)
utils::loading_cache: add remove(key)/remove(iterator) methods
cql3::query_processor: properly stop() prepared_statements_cache object
This commit is contained in:
189
cql3/authorized_prepared_statements_cache.hh
Normal file
189
cql3/authorized_prepared_statements_cache.hh
Normal file
@@ -0,0 +1,189 @@
|
||||
/*
|
||||
* Copyright (C) 2018 ScyllaDB
|
||||
*/
|
||||
|
||||
/*
|
||||
* This file is part of Scylla.
|
||||
*
|
||||
* Scylla is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* Scylla is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with Scylla. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cql3/prepared_statements_cache.hh"
|
||||
|
||||
namespace cql3 {
|
||||
|
||||
struct authorized_prepared_statements_cache_size {
|
||||
size_t operator()(const statements::prepared_statement::checked_weak_ptr& val) {
|
||||
// TODO: improve the size approximation - most of the entry is occupied by the key here.
|
||||
return 100;
|
||||
}
|
||||
};
|
||||
|
||||
class authorized_prepared_statements_cache_key {
|
||||
public:
|
||||
using cache_key_type = std::pair<auth::authenticated_user, typename cql3::prepared_cache_key_type::cache_key_type>;
|
||||
private:
|
||||
cache_key_type _key;
|
||||
|
||||
public:
|
||||
authorized_prepared_statements_cache_key(auth::authenticated_user user, cql3::prepared_cache_key_type prepared_cache_key)
|
||||
: _key(std::move(user), std::move(prepared_cache_key.key())) {}
|
||||
|
||||
cache_key_type& key() { return _key; }
|
||||
|
||||
const cache_key_type& key() const { return _key; }
|
||||
|
||||
bool operator==(const authorized_prepared_statements_cache_key& other) const {
|
||||
return _key == other._key;
|
||||
}
|
||||
|
||||
bool operator!=(const authorized_prepared_statements_cache_key& other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
|
||||
static size_t hash(const auth::authenticated_user& user, const cql3::prepared_cache_key_type::cache_key_type& prep_cache_key) {
|
||||
return utils::hash_combine(std::hash<auth::authenticated_user>()(user), utils::tuple_hash()(prep_cache_key));
|
||||
}
|
||||
};
|
||||
|
||||
/// \class authorized_prepared_statements_cache
|
||||
/// \brief A cache of previously authorized statements.
|
||||
///
|
||||
/// Entries are inserted every time a new statement is authorized.
|
||||
/// Entries are evicted in any of the following cases:
|
||||
/// - When the corresponding prepared statement is not valid anymore.
|
||||
/// - Periodically, with the same period as the permission cache is refreshed.
|
||||
/// - If the corresponding entry hasn't been used for \ref entry_expiry.
|
||||
class authorized_prepared_statements_cache {
|
||||
public:
|
||||
struct stats {
|
||||
uint64_t authorized_prepared_statements_cache_evictions = 0;
|
||||
};
|
||||
|
||||
static stats& shard_stats() {
|
||||
static thread_local stats _stats;
|
||||
return _stats;
|
||||
}
|
||||
|
||||
struct authorized_prepared_statements_cache_stats_updater {
|
||||
static void inc_hits() noexcept {}
|
||||
static void inc_misses() noexcept {}
|
||||
static void inc_blocks() noexcept {}
|
||||
static void inc_evictions() noexcept {
|
||||
++shard_stats().authorized_prepared_statements_cache_evictions;
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
using cache_key_type = authorized_prepared_statements_cache_key;
|
||||
using checked_weak_ptr = typename statements::prepared_statement::checked_weak_ptr;
|
||||
using cache_type = utils::loading_cache<cache_key_type,
|
||||
checked_weak_ptr,
|
||||
utils::loading_cache_reload_enabled::yes,
|
||||
authorized_prepared_statements_cache_size,
|
||||
std::hash<cache_key_type>,
|
||||
std::equal_to<cache_key_type>,
|
||||
authorized_prepared_statements_cache_stats_updater>;
|
||||
|
||||
static const std::chrono::minutes entry_expiry;
|
||||
|
||||
public:
|
||||
using key_type = cache_key_type;
|
||||
using value_type = checked_weak_ptr;
|
||||
using entry_is_too_big = typename cache_type::entry_is_too_big;
|
||||
using iterator = typename cache_type::iterator;
|
||||
|
||||
private:
|
||||
cache_type _cache;
|
||||
logging::logger& _logger;
|
||||
|
||||
public:
|
||||
// Choose the memory budget such that would allow us ~4K entries when a shard gets 1GB of RAM
|
||||
authorized_prepared_statements_cache(std::chrono::milliseconds entry_refresh, logging::logger& logger)
|
||||
: _cache(memory::stats().total_memory() / 2560, entry_expiry, entry_refresh, logger, [this] (const key_type& k) {
|
||||
_cache.remove(k);
|
||||
return make_ready_future<value_type>();
|
||||
})
|
||||
, _logger(logger)
|
||||
{}
|
||||
|
||||
future<> insert(auth::authenticated_user user, cql3::prepared_cache_key_type prep_cache_key, value_type v) noexcept {
|
||||
return _cache.get_ptr(key_type(std::move(user), std::move(prep_cache_key)), [v = std::move(v)] (const cache_key_type&) mutable {
|
||||
return make_ready_future<value_type>(std::move(v));
|
||||
}).discard_result();
|
||||
}
|
||||
|
||||
iterator find(const auth::authenticated_user& user, const cql3::prepared_cache_key_type& prep_cache_key) {
|
||||
struct key_view {
|
||||
const auth::authenticated_user& user_ref;
|
||||
const cql3::prepared_cache_key_type& prep_cache_key_ref;
|
||||
};
|
||||
|
||||
struct hasher {
|
||||
size_t operator()(const key_view& kv) {
|
||||
return cql3::authorized_prepared_statements_cache_key::hash(kv.user_ref, kv.prep_cache_key_ref.key());
|
||||
}
|
||||
};
|
||||
|
||||
struct equal {
|
||||
bool operator()(const key_type& k1, const key_view& k2) {
|
||||
return k1.key().first == k2.user_ref && k1.key().second == k2.prep_cache_key_ref.key();
|
||||
}
|
||||
|
||||
bool operator()(const key_view& k2, const key_type& k1) {
|
||||
return operator()(k1, k2);
|
||||
}
|
||||
};
|
||||
|
||||
return _cache.find(key_view{user, prep_cache_key}, hasher(), equal());
|
||||
}
|
||||
|
||||
iterator end() {
|
||||
return _cache.end();
|
||||
}
|
||||
|
||||
void remove(const auth::authenticated_user& user, const cql3::prepared_cache_key_type& prep_cache_key) {
|
||||
iterator it = find(user, prep_cache_key);
|
||||
_cache.remove(it);
|
||||
}
|
||||
|
||||
size_t size() const {
|
||||
return _cache.size();
|
||||
}
|
||||
|
||||
size_t memory_footprint() const {
|
||||
return _cache.memory_footprint();
|
||||
}
|
||||
|
||||
future<> stop() {
|
||||
return _cache.stop();
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
namespace std {
|
||||
template <>
|
||||
struct hash<cql3::authorized_prepared_statements_cache_key> final {
|
||||
size_t operator()(const cql3::authorized_prepared_statements_cache_key& k) const {
|
||||
return cql3::authorized_prepared_statements_cache_key::hash(k.key().first, k.key().second);
|
||||
}
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out, const cql3::authorized_prepared_statements_cache_key& k) {
|
||||
return out << "{ " << k.key().first << ", " << k.key().second << " }";
|
||||
}
|
||||
}
|
||||
@@ -68,6 +68,14 @@ public:
|
||||
static thrift_prepared_id_type thrift_id(const prepared_cache_key_type& key) {
|
||||
return key.key().second;
|
||||
}
|
||||
|
||||
bool operator==(const prepared_cache_key_type& other) const {
|
||||
return _key == other._key;
|
||||
}
|
||||
|
||||
bool operator!=(const prepared_cache_key_type& other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
};
|
||||
|
||||
class prepared_statements_cache {
|
||||
@@ -155,6 +163,10 @@ public:
|
||||
size_t memory_footprint() const {
|
||||
return _cache.memory_footprint();
|
||||
}
|
||||
|
||||
future<> stop() {
|
||||
return _cache.stop();
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@@ -168,4 +180,11 @@ inline std::ostream& operator<<(std::ostream& os, const cql3::prepared_cache_key
|
||||
os << p.key();
|
||||
return os;
|
||||
}
|
||||
|
||||
template<>
|
||||
struct hash<cql3::prepared_cache_key_type> final {
|
||||
size_t operator()(const cql3::prepared_cache_key_type& k) const {
|
||||
return utils::tuple_hash()(k.key());
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@@ -58,12 +58,14 @@ using namespace cql_transport::messages;
|
||||
|
||||
logging::logger log("query_processor");
|
||||
logging::logger prep_cache_log("prepared_statements_cache");
|
||||
logging::logger authorized_prepared_statements_cache_log("authorized_prepared_statements_cache");
|
||||
|
||||
distributed<query_processor> _the_query_processor;
|
||||
|
||||
const sstring query_processor::CQL_VERSION = "3.3.1";
|
||||
|
||||
const std::chrono::minutes prepared_statements_cache::entry_expiry = std::chrono::minutes(60);
|
||||
const std::chrono::minutes authorized_prepared_statements_cache::entry_expiry = std::chrono::minutes(60);
|
||||
|
||||
class query_processor::internal_state {
|
||||
service::query_state _qs;
|
||||
@@ -96,7 +98,8 @@ query_processor::query_processor(service::storage_proxy& proxy, distributed<data
|
||||
, _proxy(proxy)
|
||||
, _db(db)
|
||||
, _internal_state(new internal_state())
|
||||
, _prepared_cache(prep_cache_log) {
|
||||
, _prepared_cache(prep_cache_log)
|
||||
, _authorized_prepared_cache(std::chrono::milliseconds(_db.local().get_config().permissions_update_interval_in_ms()), authorized_prepared_statements_cache_log) {
|
||||
namespace sm = seastar::metrics;
|
||||
|
||||
_metrics.add_group(
|
||||
@@ -199,7 +202,23 @@ query_processor::query_processor(service::storage_proxy& proxy, distributed<data
|
||||
sm::make_derive(
|
||||
"secondary_index_rows_read",
|
||||
_cql_stats.secondary_index_rows_read,
|
||||
sm::description("Counts a total number of rows read during CQL requests performed using secondary indexes."))
|
||||
sm::description("Counts a total number of rows read during CQL requests performed using secondary indexes.")),
|
||||
|
||||
sm::make_derive(
|
||||
"authorized_prepared_statements_cache_evictions",
|
||||
[] { return authorized_prepared_statements_cache::shard_stats().authorized_prepared_statements_cache_evictions; },
|
||||
sm::description("Counts a number of authenticated prepared statements cache entries evictions.")),
|
||||
|
||||
sm::make_gauge(
|
||||
"authorized_prepared_statements_cache_size",
|
||||
[this] { return _authorized_prepared_cache.size(); },
|
||||
sm::description("A number of entries in the authenticated prepared statements cache.")),
|
||||
|
||||
sm::make_gauge(
|
||||
"user_prepared_auth_cache_footprint",
|
||||
[this] { return _authorized_prepared_cache.memory_footprint(); },
|
||||
sm::description("Size (in bytes) of the authenticated prepared statements cache."))
|
||||
|
||||
});
|
||||
|
||||
service::get_local_migration_manager().register_listener(_migration_subscriber.get());
|
||||
@@ -210,7 +229,7 @@ query_processor::~query_processor() {
|
||||
|
||||
future<> query_processor::stop() {
|
||||
service::get_local_migration_manager().unregister_listener(_migration_subscriber.get());
|
||||
return make_ready_future<>();
|
||||
return _authorized_prepared_cache.stop().finally([this] { return _prepared_cache.stop(); });
|
||||
}
|
||||
|
||||
future<::shared_ptr<result_message>>
|
||||
@@ -230,33 +249,60 @@ query_processor::process(const sstring_view& query_string, service::query_state&
|
||||
metrics.regularStatementsExecuted.inc();
|
||||
#endif
|
||||
tracing::trace(query_state.get_trace_state(), "Processing a statement");
|
||||
return process_statement(std::move(cql_statement), query_state, options);
|
||||
return process_statement_unprepared(std::move(cql_statement), query_state, options);
|
||||
}
|
||||
|
||||
future<::shared_ptr<result_message>>
|
||||
query_processor::process_statement(
|
||||
query_processor::process_statement_unprepared(
|
||||
::shared_ptr<cql_statement> statement,
|
||||
service::query_state& query_state,
|
||||
const query_options& options) {
|
||||
return statement->check_access(query_state.get_client_state()).then([this, statement, &query_state, &options]() {
|
||||
auto& client_state = query_state.get_client_state();
|
||||
return statement->check_access(query_state.get_client_state()).then([this, statement, &query_state, &options] () mutable {
|
||||
return process_authorized_statement(std::move(statement), query_state, options);
|
||||
});
|
||||
}
|
||||
|
||||
statement->validate(_proxy, client_state);
|
||||
future<::shared_ptr<result_message>>
|
||||
query_processor::process_statement_prepared(
|
||||
statements::prepared_statement::checked_weak_ptr prepared,
|
||||
cql3::prepared_cache_key_type cache_key,
|
||||
service::query_state& query_state,
|
||||
const query_options& options,
|
||||
bool needs_authorization) {
|
||||
|
||||
auto fut = make_ready_future<::shared_ptr<cql_transport::messages::result_message>>();
|
||||
if (client_state.is_internal()) {
|
||||
fut = statement->execute_internal(_proxy, query_state, options);
|
||||
} else {
|
||||
fut = statement->execute(_proxy, query_state, options);
|
||||
}
|
||||
|
||||
return fut.then([statement] (auto msg) {
|
||||
if (msg) {
|
||||
return make_ready_future<::shared_ptr<result_message>>(std::move(msg));
|
||||
}
|
||||
return make_ready_future<::shared_ptr<result_message>>(
|
||||
::make_shared<result_message::void_message>());
|
||||
::shared_ptr<cql_statement> statement = prepared->statement;
|
||||
future<> fut = make_ready_future<>();
|
||||
if (needs_authorization) {
|
||||
fut = statement->check_access(query_state.get_client_state()).then([this, &query_state, prepared = std::move(prepared), cache_key = std::move(cache_key)] () mutable {
|
||||
return _authorized_prepared_cache.insert(*query_state.get_client_state().user(), std::move(cache_key), std::move(prepared)).handle_exception([this] (auto eptr) {
|
||||
log.error("failed to cache the entry", eptr);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
return fut.then([this, statement = std::move(statement), &query_state, &options] () mutable {
|
||||
return process_authorized_statement(std::move(statement), query_state, options);
|
||||
});
|
||||
}
|
||||
|
||||
future<::shared_ptr<result_message>>
|
||||
query_processor::process_authorized_statement(const ::shared_ptr<cql_statement> statement, service::query_state& query_state, const query_options& options) {
|
||||
auto& client_state = query_state.get_client_state();
|
||||
|
||||
statement->validate(_proxy, client_state);
|
||||
|
||||
auto fut = make_ready_future<::shared_ptr<cql_transport::messages::result_message>>();
|
||||
if (client_state.is_internal()) {
|
||||
fut = statement->execute_internal(_proxy, query_state, options);
|
||||
} else {
|
||||
fut = statement->execute(_proxy, query_state, options);
|
||||
}
|
||||
|
||||
return fut.then([statement] (auto msg) {
|
||||
if (msg) {
|
||||
return make_ready_future<::shared_ptr<result_message>>(std::move(msg));
|
||||
}
|
||||
return make_ready_future<::shared_ptr<result_message>>(::make_shared<result_message::void_message>());
|
||||
});
|
||||
}
|
||||
|
||||
@@ -558,11 +604,18 @@ future<::shared_ptr<cql_transport::messages::result_message>>
|
||||
query_processor::process_batch(
|
||||
::shared_ptr<statements::batch_statement> batch,
|
||||
service::query_state& query_state,
|
||||
query_options& options) {
|
||||
return batch->check_access(query_state.get_client_state()).then([this, &query_state, &options, batch] {
|
||||
batch->validate();
|
||||
batch->validate(_proxy, query_state.get_client_state());
|
||||
return batch->execute(_proxy, query_state, options);
|
||||
query_options& options,
|
||||
std::unordered_map<prepared_cache_key_type, authorized_prepared_statements_cache::value_type> pending_authorization_entries) {
|
||||
return batch->check_access(query_state.get_client_state()).then([this, &query_state, &options, batch, pending_authorization_entries = std::move(pending_authorization_entries)] () mutable {
|
||||
return parallel_for_each(pending_authorization_entries, [this, &query_state] (auto& e) {
|
||||
return _authorized_prepared_cache.insert(*query_state.get_client_state().user(), e.first, std::move(e.second)).handle_exception([this] (auto eptr) {
|
||||
log.error("failed to cache the entry", eptr);
|
||||
});
|
||||
}).then([this, &query_state, &options, batch] {
|
||||
batch->validate();
|
||||
batch->validate(_proxy, query_state.get_client_state());
|
||||
return batch->execute(_proxy, query_state, options);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -49,6 +49,7 @@
|
||||
#include <seastar/core/shared_ptr.hh>
|
||||
|
||||
#include "cql3/prepared_statements_cache.hh"
|
||||
#include "cql3/authorized_prepared_statements_cache.hh"
|
||||
#include "cql3/query_options.hh"
|
||||
#include "cql3/statements/prepared_statement.hh"
|
||||
#include "cql3/statements/raw/parsed_statement.hh"
|
||||
@@ -117,6 +118,7 @@ private:
|
||||
std::unique_ptr<internal_state> _internal_state;
|
||||
|
||||
prepared_statements_cache _prepared_cache;
|
||||
authorized_prepared_statements_cache _authorized_prepared_cache;
|
||||
|
||||
// A map for prepared statements used internally (which we don't want to mix with user statement, in particular we
|
||||
// don't bother with expiration on those.
|
||||
@@ -151,6 +153,21 @@ public:
|
||||
return _cql_stats;
|
||||
}
|
||||
|
||||
statements::prepared_statement::checked_weak_ptr get_prepared(const auth::authenticated_user* user_ptr, const prepared_cache_key_type& key) {
|
||||
if (user_ptr) {
|
||||
auto it = _authorized_prepared_cache.find(*user_ptr, key);
|
||||
if (it != _authorized_prepared_cache.end()) {
|
||||
try {
|
||||
return it->get()->checked_weak_from_this();
|
||||
} catch (seastar::checked_ptr_is_null_exception&) {
|
||||
// If the prepared statement got invalidated - remove the corresponding authorized_prepared_statements_cache entry as well.
|
||||
_authorized_prepared_cache.remove(*user_ptr, key);
|
||||
}
|
||||
}
|
||||
}
|
||||
return statements::prepared_statement::checked_weak_ptr();
|
||||
}
|
||||
|
||||
statements::prepared_statement::checked_weak_ptr get_prepared(const prepared_cache_key_type& key) {
|
||||
auto it = _prepared_cache.find(key);
|
||||
if (it == _prepared_cache.end()) {
|
||||
@@ -160,11 +177,19 @@ public:
|
||||
}
|
||||
|
||||
future<::shared_ptr<cql_transport::messages::result_message>>
|
||||
process_statement(
|
||||
process_statement_unprepared(
|
||||
::shared_ptr<cql_statement> statement,
|
||||
service::query_state& query_state,
|
||||
const query_options& options);
|
||||
|
||||
future<::shared_ptr<cql_transport::messages::result_message>>
|
||||
process_statement_prepared(
|
||||
statements::prepared_statement::checked_weak_ptr statement,
|
||||
cql3::prepared_cache_key_type cache_key,
|
||||
service::query_state& query_state,
|
||||
const query_options& options,
|
||||
bool needs_authorization);
|
||||
|
||||
future<::shared_ptr<cql_transport::messages::result_message>>
|
||||
process(
|
||||
const std::experimental::string_view& query_string,
|
||||
@@ -242,7 +267,11 @@ public:
|
||||
future<> stop();
|
||||
|
||||
future<::shared_ptr<cql_transport::messages::result_message>>
|
||||
process_batch(::shared_ptr<statements::batch_statement>, service::query_state& query_state, query_options& options);
|
||||
process_batch(
|
||||
::shared_ptr<statements::batch_statement>,
|
||||
service::query_state& query_state,
|
||||
query_options& options,
|
||||
std::unordered_map<prepared_cache_key_type, authorized_prepared_statements_cache::value_type> pending_authorization_entries);
|
||||
|
||||
std::unique_ptr<statements::prepared_statement> get_statement(
|
||||
const std::experimental::string_view& query,
|
||||
@@ -257,6 +286,9 @@ private:
|
||||
db::consistency_level = db::consistency_level::ONE,
|
||||
int32_t page_size = -1);
|
||||
|
||||
future<::shared_ptr<cql_transport::messages::result_message>>
|
||||
process_authorized_statement(const ::shared_ptr<cql_statement> statement, service::query_state& query_state, const query_options& options);
|
||||
|
||||
/*!
|
||||
* \brief created a state object for paging
|
||||
*
|
||||
|
||||
@@ -75,19 +75,19 @@ timeout_for_type(batch_statement::type t) {
|
||||
}
|
||||
|
||||
batch_statement::batch_statement(int bound_terms, type type_,
|
||||
std::vector<shared_ptr<modification_statement>> statements,
|
||||
std::vector<single_statement> statements,
|
||||
std::unique_ptr<attributes> attrs,
|
||||
cql_stats& stats)
|
||||
: cql_statement_no_metadata(timeout_for_type(type_))
|
||||
, _bound_terms(bound_terms), _type(type_), _statements(std::move(statements))
|
||||
, _attrs(std::move(attrs))
|
||||
, _has_conditions(boost::algorithm::any_of(_statements, std::mem_fn(&modification_statement::has_conditions)))
|
||||
, _has_conditions(boost::algorithm::any_of(_statements, [] (auto&& s) { return s.statement->has_conditions(); }))
|
||||
, _stats(stats)
|
||||
{
|
||||
}
|
||||
|
||||
batch_statement::batch_statement(type type_,
|
||||
std::vector<shared_ptr<modification_statement>> statements,
|
||||
std::vector<single_statement> statements,
|
||||
std::unique_ptr<attributes> attrs,
|
||||
cql_stats& stats)
|
||||
: batch_statement(-1, type_, std::move(statements), std::move(attrs), stats)
|
||||
@@ -97,7 +97,7 @@ batch_statement::batch_statement(type type_,
|
||||
bool batch_statement::uses_function(const sstring& ks_name, const sstring& function_name) const
|
||||
{
|
||||
return _attrs->uses_function(ks_name, function_name)
|
||||
|| boost::algorithm::any_of(_statements, [&] (auto&& s) { return s->uses_function(ks_name, function_name); });
|
||||
|| boost::algorithm::any_of(_statements, [&] (auto&& s) { return s.statement->uses_function(ks_name, function_name); });
|
||||
}
|
||||
|
||||
bool batch_statement::depends_on_keyspace(const sstring& ks_name) const
|
||||
@@ -118,7 +118,11 @@ uint32_t batch_statement::get_bound_terms()
|
||||
future<> batch_statement::check_access(const service::client_state& state)
|
||||
{
|
||||
return parallel_for_each(_statements.begin(), _statements.end(), [&state](auto&& s) {
|
||||
return s->check_access(state);
|
||||
if (s.needs_authorization) {
|
||||
return s.statement->check_access(state);
|
||||
} else {
|
||||
return make_ready_future<>();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -138,12 +142,12 @@ void batch_statement::validate()
|
||||
}
|
||||
}
|
||||
|
||||
bool has_counters = boost::algorithm::any_of(_statements, std::mem_fn(&modification_statement::is_counter));
|
||||
bool has_non_counters = !boost::algorithm::all_of(_statements, std::mem_fn(&modification_statement::is_counter));
|
||||
bool has_counters = boost::algorithm::any_of(_statements, [] (auto&& s) { return s.statement->is_counter(); });
|
||||
bool has_non_counters = !boost::algorithm::all_of(_statements, [] (auto&& s) { return s.statement->is_counter(); });
|
||||
if (timestamp_set && has_counters) {
|
||||
throw exceptions::invalid_request_exception("Cannot provide custom timestamp for a BATCH containing counters");
|
||||
}
|
||||
if (timestamp_set && boost::algorithm::any_of(_statements, std::mem_fn(&modification_statement::is_timestamp_set))) {
|
||||
if (timestamp_set && boost::algorithm::any_of(_statements, [] (auto&& s) { return s.statement->is_timestamp_set(); })) {
|
||||
throw exceptions::invalid_request_exception("Timestamp must be set either on BATCH or individual statements");
|
||||
}
|
||||
if (_type == type::COUNTER && has_non_counters) {
|
||||
@@ -159,30 +163,30 @@ void batch_statement::validate()
|
||||
if (_has_conditions
|
||||
&& !_statements.empty()
|
||||
&& (boost::distance(_statements
|
||||
| boost::adaptors::transformed(std::mem_fn(&modification_statement::keyspace))
|
||||
| boost::adaptors::transformed([] (auto&& s) { return s.statement->keyspace(); })
|
||||
| boost::adaptors::uniqued) != 1
|
||||
|| (boost::distance(_statements
|
||||
| boost::adaptors::transformed(std::mem_fn(&modification_statement::column_family))
|
||||
| boost::adaptors::transformed([] (auto&& s) { return s.statement->column_family(); })
|
||||
| boost::adaptors::uniqued) != 1))) {
|
||||
throw exceptions::invalid_request_exception("Batch with conditions cannot span multiple tables");
|
||||
}
|
||||
std::experimental::optional<bool> raw_counter;
|
||||
for (auto& s : _statements) {
|
||||
if (raw_counter && s->is_raw_counter_shard_write() != *raw_counter) {
|
||||
if (raw_counter && s.statement->is_raw_counter_shard_write() != *raw_counter) {
|
||||
throw exceptions::invalid_request_exception("Cannot mix raw and regular counter statements in batch");
|
||||
}
|
||||
raw_counter = s->is_raw_counter_shard_write();
|
||||
raw_counter = s.statement->is_raw_counter_shard_write();
|
||||
}
|
||||
}
|
||||
|
||||
void batch_statement::validate(service::storage_proxy& proxy, const service::client_state& state)
|
||||
{
|
||||
for (auto&& s : _statements) {
|
||||
s->validate(proxy, state);
|
||||
s.statement->validate(proxy, state);
|
||||
}
|
||||
}
|
||||
|
||||
const std::vector<shared_ptr<modification_statement>>& batch_statement::get_statements()
|
||||
const std::vector<batch_statement::single_statement>& batch_statement::get_statements()
|
||||
{
|
||||
return _statements;
|
||||
}
|
||||
@@ -196,7 +200,7 @@ future<std::vector<mutation>> batch_statement::get_mutations(service::storage_pr
|
||||
return do_for_each(boost::make_counting_iterator<size_t>(0),
|
||||
boost::make_counting_iterator<size_t>(_statements.size()),
|
||||
[this, &storage, &options, now, local, &result, trace_state] (size_t i) {
|
||||
auto&& statement = _statements[i];
|
||||
auto&& statement = _statements[i].statement;
|
||||
statement->inc_cql_stats();
|
||||
auto&& statement_options = options.for_statement(i);
|
||||
auto timestamp = _attrs->get_timestamp(now, statement_options);
|
||||
@@ -426,7 +430,9 @@ batch_statement::prepare(database& db, cql_stats& stats) {
|
||||
stdx::optional<sstring> first_cf;
|
||||
bool have_multiple_cfs = false;
|
||||
|
||||
std::vector<shared_ptr<cql3::statements::modification_statement>> statements;
|
||||
std::vector<cql3::statements::batch_statement::single_statement> statements;
|
||||
statements.reserve(_parsed_statements.size());
|
||||
|
||||
for (auto&& parsed : _parsed_statements) {
|
||||
if (!first_ks) {
|
||||
first_ks = parsed->keyspace();
|
||||
@@ -434,7 +440,7 @@ batch_statement::prepare(database& db, cql_stats& stats) {
|
||||
} else {
|
||||
have_multiple_cfs = first_ks.value() != parsed->keyspace() || first_cf.value() != parsed->column_family();
|
||||
}
|
||||
statements.push_back(parsed->prepare(db, bound_names, stats));
|
||||
statements.emplace_back(parsed->prepare(db, bound_names, stats));
|
||||
}
|
||||
|
||||
auto&& prep_attrs = _attrs->prepare(db, "[batch]", "[batch]");
|
||||
@@ -445,7 +451,7 @@ batch_statement::prepare(database& db, cql_stats& stats) {
|
||||
|
||||
std::vector<uint16_t> partition_key_bind_indices;
|
||||
if (!have_multiple_cfs && batch_statement_.get_statements().size() > 0) {
|
||||
partition_key_bind_indices = bound_names->get_partition_key_bind_indexes(batch_statement_.get_statements()[0]->s);
|
||||
partition_key_bind_indices = bound_names->get_partition_key_bind_indexes(batch_statement_.get_statements()[0].statement->s);
|
||||
}
|
||||
return std::make_unique<prepared>(make_shared(std::move(batch_statement_)),
|
||||
bound_names->get_specifications(),
|
||||
|
||||
@@ -66,10 +66,24 @@ class batch_statement : public cql_statement_no_metadata {
|
||||
static logging::logger _logger;
|
||||
public:
|
||||
using type = raw::batch_statement::type;
|
||||
|
||||
struct single_statement {
|
||||
shared_ptr<modification_statement> statement;
|
||||
bool needs_authorization = true;
|
||||
|
||||
public:
|
||||
single_statement(shared_ptr<modification_statement> s)
|
||||
: statement(std::move(s))
|
||||
{}
|
||||
single_statement(shared_ptr<modification_statement> s, bool na)
|
||||
: statement(std::move(s))
|
||||
, needs_authorization(na)
|
||||
{}
|
||||
};
|
||||
private:
|
||||
int _bound_terms;
|
||||
type _type;
|
||||
std::vector<shared_ptr<modification_statement>> _statements;
|
||||
std::vector<single_statement> _statements;
|
||||
std::unique_ptr<attributes> _attrs;
|
||||
bool _has_conditions;
|
||||
cql_stats& _stats;
|
||||
@@ -83,12 +97,12 @@ public:
|
||||
* @param attrs additional attributes for statement (CL, timestamp, timeToLive)
|
||||
*/
|
||||
batch_statement(int bound_terms, type type_,
|
||||
std::vector<shared_ptr<modification_statement>> statements,
|
||||
std::vector<single_statement> statements,
|
||||
std::unique_ptr<attributes> attrs,
|
||||
cql_stats& stats);
|
||||
|
||||
batch_statement(type type_,
|
||||
std::vector<shared_ptr<modification_statement>> statements,
|
||||
std::vector<single_statement> statements,
|
||||
std::unique_ptr<attributes> attrs,
|
||||
cql_stats& stats);
|
||||
|
||||
@@ -109,7 +123,7 @@ public:
|
||||
// or in QueryProcessor.processBatch() - for native protocol batches.
|
||||
virtual void validate(service::storage_proxy& proxy, const service::client_state& state) override;
|
||||
|
||||
const std::vector<shared_ptr<modification_statement>>& get_statements();
|
||||
const std::vector<single_statement>& get_statements();
|
||||
private:
|
||||
future<std::vector<mutation>> get_mutations(service::storage_proxy& storage, const query_options& options, bool local, api::timestamp_type now, tracing::trace_state_ptr trace_state);
|
||||
|
||||
|
||||
@@ -166,7 +166,7 @@ public:
|
||||
options->prepare(prepared->bound_names);
|
||||
|
||||
auto qs = make_query_state();
|
||||
return local_qp().process_statement(stmt, *qs, *options)
|
||||
return local_qp().process_statement_prepared(std::move(prepared), std::move(id), *qs, *options, true)
|
||||
.finally([options, qs, this] {
|
||||
_core_local.local().client_state.merge(qs->get_client_state());
|
||||
});
|
||||
|
||||
@@ -229,6 +229,41 @@ SEASTAR_TEST_CASE(test_loading_cache_loading_same_key) {
|
||||
});
|
||||
}
|
||||
|
||||
SEASTAR_THREAD_TEST_CASE(test_loading_cache_removing_key) {
|
||||
using namespace std::chrono;
|
||||
load_count = 0;
|
||||
utils::loading_cache<int, sstring> loading_cache(num_loaders, 100s, test_logger);
|
||||
auto stop_cache_reload = seastar::defer([&loading_cache] { loading_cache.stop().get(); });
|
||||
|
||||
prepare().get();
|
||||
|
||||
loading_cache.get_ptr(0, loader).discard_result().get();
|
||||
BOOST_REQUIRE_EQUAL(load_count, 1);
|
||||
BOOST_REQUIRE(loading_cache.find(0) != loading_cache.end());
|
||||
|
||||
loading_cache.remove(0);
|
||||
BOOST_REQUIRE(loading_cache.find(0) == loading_cache.end());
|
||||
}
|
||||
|
||||
SEASTAR_THREAD_TEST_CASE(test_loading_cache_removing_iterator) {
|
||||
using namespace std::chrono;
|
||||
load_count = 0;
|
||||
utils::loading_cache<int, sstring> loading_cache(num_loaders, 100s, test_logger);
|
||||
auto stop_cache_reload = seastar::defer([&loading_cache] { loading_cache.stop().get(); });
|
||||
|
||||
prepare().get();
|
||||
|
||||
loading_cache.get_ptr(0, loader).discard_result().get();
|
||||
BOOST_REQUIRE_EQUAL(load_count, 1);
|
||||
|
||||
auto it = loading_cache.find(0);
|
||||
|
||||
BOOST_REQUIRE(it != loading_cache.end());
|
||||
|
||||
loading_cache.remove(it);
|
||||
BOOST_REQUIRE(loading_cache.find(0) == loading_cache.end());
|
||||
}
|
||||
|
||||
SEASTAR_TEST_CASE(test_loading_cache_loading_different_keys) {
|
||||
return seastar::async([] {
|
||||
using namespace std::chrono;
|
||||
|
||||
@@ -999,9 +999,17 @@ public:
|
||||
|
||||
void execute_prepared_cql3_query(tcxx::function<void(CqlResult const& _return)> cob, tcxx::function<void(::apache::thrift::TDelayedException* _throw)> exn_cob, const int32_t itemId, const std::vector<std::string> & values, const ConsistencyLevel::type consistency) {
|
||||
with_exn_cob(std::move(exn_cob), [&] {
|
||||
auto prepared = _query_processor.local().get_prepared(cql3::prepared_cache_key_type(itemId));
|
||||
cql3::prepared_cache_key_type cache_key(itemId);
|
||||
bool needs_authorization = false;
|
||||
|
||||
auto prepared = _query_processor.local().get_prepared(_query_state.get_client_state().user().get(), cache_key);
|
||||
if (!prepared) {
|
||||
throw make_exception<InvalidRequestException>("Prepared query with id %d not found", itemId);
|
||||
needs_authorization = true;
|
||||
|
||||
prepared = _query_processor.local().get_prepared(cache_key);
|
||||
if (!prepared) {
|
||||
throw make_exception<InvalidRequestException>("Prepared query with id %d not found", itemId);
|
||||
}
|
||||
}
|
||||
auto stmt = prepared->statement;
|
||||
if (stmt->get_bound_terms() != values.size()) {
|
||||
@@ -1013,7 +1021,7 @@ public:
|
||||
});
|
||||
auto opts = std::make_unique<cql3::query_options>(cl_from_thrift(consistency), _timeout_config, stdx::nullopt, std::move(bytes_values),
|
||||
false, cql3::query_options::specific_options::DEFAULT, cql_serialization_format::latest());
|
||||
auto f = _query_processor.local().process_statement(stmt, _query_state, *opts);
|
||||
auto f = _query_processor.local().process_statement_prepared(std::move(prepared), std::move(cache_key), _query_state, *opts, needs_authorization);
|
||||
return f.then([cob = std::move(cob), opts = std::move(opts)](auto&& ret) {
|
||||
cql3_result_visitor visitor;
|
||||
ret->accept(visitor);
|
||||
|
||||
@@ -357,7 +357,7 @@ future<> trace_keyspace_helper::apply_events_mutation(lw_shared_ptr<one_session_
|
||||
return _events.cache_table_info(_dummy_query_state).then([this, records, &events_records] {
|
||||
tlogger.trace("{}: storing {} events records: parent_id {} span_id {}", records->session_id, events_records.size(), records->parent_id, records->my_span_id);
|
||||
|
||||
std::vector<shared_ptr<cql3::statements::modification_statement>> modifications(events_records.size(), _events.insert_stmt());
|
||||
std::vector<cql3::statements::batch_statement::single_statement> modifications(events_records.size(), cql3::statements::batch_statement::single_statement(_events.insert_stmt(), false));
|
||||
std::vector<std::vector<cql3::raw_value>> values;
|
||||
auto& qp = cql3::get_local_query_processor();
|
||||
|
||||
|
||||
@@ -910,7 +910,16 @@ future<response_type> cql_server::connection::process_execute(uint16_t stream, b
|
||||
{
|
||||
cql3::prepared_cache_key_type cache_key(read_short_bytes(buf));
|
||||
auto& id = cql3::prepared_cache_key_type::cql_id(cache_key);
|
||||
auto prepared = _server._query_processor.local().get_prepared(cache_key);
|
||||
bool needs_authorization = false;
|
||||
|
||||
// First, try to lookup in the cache of already authorized statements. If the corresponding entry is not found there
|
||||
// look for the prepared statement and then authorize it.
|
||||
auto prepared = _server._query_processor.local().get_prepared(client_state.user().get(), cache_key);
|
||||
if (!prepared) {
|
||||
needs_authorization = true;
|
||||
prepared = _server._query_processor.local().get_prepared(cache_key);
|
||||
}
|
||||
|
||||
if (!prepared) {
|
||||
throw exceptions::prepared_query_not_found_exception(id);
|
||||
}
|
||||
@@ -945,7 +954,7 @@ future<response_type> cql_server::connection::process_execute(uint16_t stream, b
|
||||
throw exceptions::invalid_request_exception("Invalid amount of bind variables");
|
||||
}
|
||||
tracing::trace(query_state.get_trace_state(), "Processing a statement");
|
||||
return _server._query_processor.local().process_statement(stmt, query_state, options).then([this, stream, buf = std::move(buf), &query_state, skip_metadata] (auto msg) {
|
||||
return _server._query_processor.local().process_statement_prepared(std::move(prepared), std::move(cache_key), query_state, options, needs_authorization).then([this, stream, buf = std::move(buf), &query_state, skip_metadata] (auto msg) {
|
||||
tracing::trace(query_state.get_trace_state(), "Done processing - preparing a result");
|
||||
return this->make_result(stream, msg, query_state.get_trace_state(), skip_metadata);
|
||||
}).then([&query_state, q_state = std::move(q_state), this] (auto&& response) {
|
||||
@@ -964,8 +973,9 @@ cql_server::connection::process_batch(uint16_t stream, bytes_view buf, service::
|
||||
const auto type = read_byte(buf);
|
||||
const unsigned n = read_short(buf);
|
||||
|
||||
std::vector<shared_ptr<cql3::statements::modification_statement>> modifications;
|
||||
std::vector<cql3::statements::batch_statement::single_statement> modifications;
|
||||
std::vector<std::vector<cql3::raw_value_view>> values;
|
||||
std::unordered_map<cql3::prepared_cache_key_type, cql3::authorized_prepared_statements_cache::value_type> pending_authorization_entries;
|
||||
|
||||
modifications.reserve(n);
|
||||
values.reserve(n);
|
||||
@@ -977,6 +987,7 @@ cql_server::connection::process_batch(uint16_t stream, bytes_view buf, service::
|
||||
|
||||
std::unique_ptr<cql3::statements::prepared_statement> stmt_ptr;
|
||||
cql3::statements::prepared_statement::checked_weak_ptr ps;
|
||||
bool needs_authorization(kind == 0);
|
||||
|
||||
switch (kind) {
|
||||
case 0: {
|
||||
@@ -988,10 +999,19 @@ cql_server::connection::process_batch(uint16_t stream, bytes_view buf, service::
|
||||
case 1: {
|
||||
cql3::prepared_cache_key_type cache_key(read_short_bytes(buf));
|
||||
auto& id = cql3::prepared_cache_key_type::cql_id(cache_key);
|
||||
ps = _server._query_processor.local().get_prepared(cache_key);
|
||||
|
||||
// First, try to lookup in the cache of already authorized statements. If the corresponding entry is not found there
|
||||
// look for the prepared statement and then authorize it.
|
||||
ps = _server._query_processor.local().get_prepared(client_state.user().get(), cache_key);
|
||||
if (!ps) {
|
||||
throw exceptions::prepared_query_not_found_exception(id);
|
||||
ps = _server._query_processor.local().get_prepared(cache_key);
|
||||
if (!ps) {
|
||||
throw exceptions::prepared_query_not_found_exception(id);
|
||||
}
|
||||
// authorize a particular prepared statement only once
|
||||
needs_authorization = pending_authorization_entries.emplace(std::move(cache_key), ps->checked_weak_from_this()).second;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
default:
|
||||
@@ -1007,7 +1027,7 @@ cql_server::connection::process_batch(uint16_t stream, bytes_view buf, service::
|
||||
::shared_ptr<cql3::statements::modification_statement> modif_statement_ptr = static_pointer_cast<cql3::statements::modification_statement>(ps->statement);
|
||||
tracing::add_table_name(client_state.get_trace_state(), modif_statement_ptr->keyspace(), modif_statement_ptr->column_family());
|
||||
|
||||
modifications.emplace_back(std::move(modif_statement_ptr));
|
||||
modifications.emplace_back(std::move(modif_statement_ptr), needs_authorization);
|
||||
|
||||
std::vector<cql3::raw_value_view> tmp;
|
||||
read_value_view_list(buf, tmp);
|
||||
@@ -1031,7 +1051,7 @@ cql_server::connection::process_batch(uint16_t stream, bytes_view buf, service::
|
||||
tracing::trace(client_state.get_trace_state(), "Creating a batch statement");
|
||||
|
||||
auto batch = ::make_shared<cql3::statements::batch_statement>(cql3::statements::batch_statement::type(type), std::move(modifications), cql3::attributes::none(), _server._query_processor.local().get_cql_stats());
|
||||
return _server._query_processor.local().process_batch(batch, query_state, options).then([this, stream, batch, &query_state] (auto msg) {
|
||||
return _server._query_processor.local().process_batch(batch, query_state, options, std::move(pending_authorization_entries)).then([this, stream, batch, &query_state] (auto msg) {
|
||||
return this->make_result(stream, msg, query_state.get_trace_state());
|
||||
}).then([&query_state, q_state = std::move(q_state), this] (auto&& response) {
|
||||
/* Keep q_state alive. */
|
||||
|
||||
@@ -25,6 +25,7 @@
|
||||
#include <unordered_map>
|
||||
#include <boost/intrusive/list.hpp>
|
||||
#include <boost/intrusive/unordered_set.hpp>
|
||||
#include <boost/intrusive/parent_from_member.hpp>
|
||||
|
||||
#include <seastar/core/reactor.hh>
|
||||
#include <seastar/core/timer.hh>
|
||||
@@ -79,6 +80,10 @@ public:
|
||||
value_type& value() noexcept { return _value; }
|
||||
const value_type& value() const noexcept { return _value; }
|
||||
|
||||
static const timestamped_val& container_of(const value_type& value) {
|
||||
return *bi::get_parent_from_member(&value, ×tamped_val::_value);
|
||||
}
|
||||
|
||||
loading_cache_clock_type::time_point last_read() const noexcept {
|
||||
return _last_read;
|
||||
}
|
||||
@@ -95,6 +100,10 @@ public:
|
||||
return _lru_entry_ptr;
|
||||
}
|
||||
|
||||
lru_entry* lru_entry_ptr() const noexcept {
|
||||
return _lru_entry_ptr;
|
||||
}
|
||||
|
||||
private:
|
||||
void touch() noexcept {
|
||||
assert(_lru_entry_ptr);
|
||||
@@ -365,6 +374,11 @@ public:
|
||||
return _timer_reads_gate.close().finally([this] { _timer.cancel(); });
|
||||
}
|
||||
|
||||
template<typename KeyType, typename KeyHasher, typename KeyEqual>
|
||||
iterator find(const KeyType& key, KeyHasher key_hasher_func, KeyEqual key_equal_func) noexcept {
|
||||
return boost::make_transform_iterator(set_find(key, std::move(key_hasher_func), std::move(key_equal_func)), _value_extractor_fn);
|
||||
};
|
||||
|
||||
iterator find(const Key& k) noexcept {
|
||||
return boost::make_transform_iterator(set_find(k), _value_extractor_fn);
|
||||
}
|
||||
@@ -388,6 +402,24 @@ public:
|
||||
});
|
||||
}
|
||||
|
||||
void remove(const Key& k) {
|
||||
auto it = set_find(k);
|
||||
if (it == set_end()) {
|
||||
return;
|
||||
}
|
||||
|
||||
_lru_list.erase_and_dispose(_lru_list.iterator_to(*it->lru_entry_ptr()), [this] (ts_value_lru_entry* p) { loading_cache::destroy_ts_value(p); });
|
||||
}
|
||||
|
||||
void remove(iterator it) {
|
||||
if (it == end()) {
|
||||
return;
|
||||
}
|
||||
|
||||
const ts_value_type& val = ts_value_type::container_of(*it);
|
||||
_lru_list.erase_and_dispose(_lru_list.iterator_to(*val.lru_entry_ptr()), [this] (ts_value_lru_entry* p) { loading_cache::destroy_ts_value(p); });
|
||||
}
|
||||
|
||||
size_t size() const {
|
||||
return _loading_values.size();
|
||||
}
|
||||
@@ -398,8 +430,7 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
set_iterator set_find(const Key& k) noexcept {
|
||||
set_iterator it = _loading_values.find(k);
|
||||
set_iterator ready_entry_iterator(set_iterator it) {
|
||||
set_iterator end_it = set_end();
|
||||
|
||||
if (it == end_it || !it->ready()) {
|
||||
@@ -408,6 +439,17 @@ private:
|
||||
return it;
|
||||
}
|
||||
|
||||
template<typename KeyType, typename KeyHasher, typename KeyEqual>
|
||||
set_iterator set_find(const KeyType& key, KeyHasher key_hasher_func, KeyEqual key_equal_func) noexcept {
|
||||
return ready_entry_iterator(_loading_values.find(key, std::move(key_hasher_func), std::move(key_equal_func)));
|
||||
}
|
||||
|
||||
// keep the default non-templated overloads to ease on the compiler for specifications
|
||||
// that do not require the templated find().
|
||||
set_iterator set_find(const Key& key) noexcept {
|
||||
return ready_entry_iterator(_loading_values.find(key));
|
||||
}
|
||||
|
||||
set_iterator set_end() noexcept {
|
||||
return _loading_values.end();
|
||||
}
|
||||
|
||||
@@ -107,16 +107,6 @@ private:
|
||||
return bool(_val);
|
||||
}
|
||||
|
||||
struct key_eq {
|
||||
bool operator()(const key_type& k, const entry& c) const {
|
||||
return EqualPred()(k, c.key());
|
||||
}
|
||||
|
||||
bool operator()(const entry& c, const key_type& k) const {
|
||||
return EqualPred()(c.key(), k);
|
||||
}
|
||||
};
|
||||
|
||||
entry(loading_shared_values& parent, key_type k)
|
||||
: _parent(parent), _key(std::move(k)) {}
|
||||
|
||||
@@ -134,6 +124,17 @@ private:
|
||||
}
|
||||
};
|
||||
|
||||
template<typename KeyType, typename KeyEqual>
|
||||
struct key_eq {
|
||||
bool operator()(const KeyType& k, const entry& c) const {
|
||||
return KeyEqual()(k, c.key());
|
||||
}
|
||||
|
||||
bool operator()(const entry& c, const KeyType& k) const {
|
||||
return KeyEqual()(c.key(), k);
|
||||
}
|
||||
};
|
||||
|
||||
using set_type = bi::unordered_set<entry, bi::power_2_buckets<true>, bi::compare_hash<true>>;
|
||||
using bi_set_bucket_traits = typename set_type::bucket_traits;
|
||||
using set_iterator = typename set_type::iterator;
|
||||
@@ -216,7 +217,7 @@ public:
|
||||
future<entry_ptr> get_or_load(const key_type& key, Loader&& loader) noexcept {
|
||||
static_assert(std::is_same<future<value_type>, typename futurize<std::result_of_t<Loader(const key_type&)>>::type>::value, "Bad Loader signature");
|
||||
try {
|
||||
auto i = _set.find(key, Hash(), typename entry::key_eq());
|
||||
auto i = _set.find(key, Hash(), key_eq<key_type, EqualPred>());
|
||||
lw_shared_ptr<entry> e;
|
||||
future<> f = make_ready_future<>();
|
||||
if (i != _set.end()) {
|
||||
@@ -280,12 +281,19 @@ public:
|
||||
return boost::make_transform_iterator(_set.begin(), _value_extractor_fn);
|
||||
}
|
||||
|
||||
iterator find(const key_type& key) noexcept {
|
||||
set_iterator it = _set.find(key, Hash(), typename entry::key_eq());
|
||||
template<typename KeyType, typename KeyHasher, typename KeyEqual>
|
||||
iterator find(const KeyType& key, KeyHasher key_hasher_func, KeyEqual key_equal_func) noexcept {
|
||||
set_iterator it = _set.find(key, std::move(key_hasher_func), key_eq<KeyType, KeyEqual>());
|
||||
if (it == _set.end() || !it->ready()) {
|
||||
return end();
|
||||
}
|
||||
return boost::make_transform_iterator(it, _value_extractor_fn);
|
||||
};
|
||||
|
||||
// keep the default non-templated overloads to ease on the compiler for specifications
|
||||
// that do not require the templated find().
|
||||
iterator find(const key_type& key) noexcept {
|
||||
return find(key, Hash(), EqualPred());
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
Reference in New Issue
Block a user