diff --git a/cql3/authorized_prepared_statements_cache.hh b/cql3/authorized_prepared_statements_cache.hh
new file mode 100644
index 0000000000..4eff0b1c4a
--- /dev/null
+++ b/cql3/authorized_prepared_statements_cache.hh
@@ -0,0 +1,189 @@
+/*
+ * Copyright (C) 2018 ScyllaDB
+ */
+
+/*
+ * This file is part of Scylla.
+ *
+ * Scylla is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * Scylla is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with Scylla. If not, see .
+ */
+
+#pragma once
+
+#include "cql3/prepared_statements_cache.hh"
+
+namespace cql3 {
+
+struct authorized_prepared_statements_cache_size {
+ size_t operator()(const statements::prepared_statement::checked_weak_ptr& val) {
+ // TODO: improve the size approximation - most of the entry is occupied by the key here.
+ return 100;
+ }
+};
+
+class authorized_prepared_statements_cache_key {
+public:
+ using cache_key_type = std::pair;
+private:
+ cache_key_type _key;
+
+public:
+ authorized_prepared_statements_cache_key(auth::authenticated_user user, cql3::prepared_cache_key_type prepared_cache_key)
+ : _key(std::move(user), std::move(prepared_cache_key.key())) {}
+
+ cache_key_type& key() { return _key; }
+
+ const cache_key_type& key() const { return _key; }
+
+ bool operator==(const authorized_prepared_statements_cache_key& other) const {
+ return _key == other._key;
+ }
+
+ bool operator!=(const authorized_prepared_statements_cache_key& other) const {
+ return !(*this == other);
+ }
+
+ static size_t hash(const auth::authenticated_user& user, const cql3::prepared_cache_key_type::cache_key_type& prep_cache_key) {
+ return utils::hash_combine(std::hash()(user), utils::tuple_hash()(prep_cache_key));
+ }
+};
+
+/// \class authorized_prepared_statements_cache
+/// \brief A cache of previously authorized statements.
+///
+/// Entries are inserted every time a new statement is authorized.
+/// Entries are evicted in any of the following cases:
+/// - When the corresponding prepared statement is not valid anymore.
+/// - Periodically, with the same period as the permission cache is refreshed.
+/// - If the corresponding entry hasn't been used for \ref entry_expiry.
+class authorized_prepared_statements_cache {
+public:
+ struct stats {
+ uint64_t authorized_prepared_statements_cache_evictions = 0;
+ };
+
+ static stats& shard_stats() {
+ static thread_local stats _stats;
+ return _stats;
+ }
+
+ struct authorized_prepared_statements_cache_stats_updater {
+ static void inc_hits() noexcept {}
+ static void inc_misses() noexcept {}
+ static void inc_blocks() noexcept {}
+ static void inc_evictions() noexcept {
+ ++shard_stats().authorized_prepared_statements_cache_evictions;
+ }
+ };
+
+private:
+ using cache_key_type = authorized_prepared_statements_cache_key;
+ using checked_weak_ptr = typename statements::prepared_statement::checked_weak_ptr;
+ using cache_type = utils::loading_cache,
+ std::equal_to,
+ authorized_prepared_statements_cache_stats_updater>;
+
+ static const std::chrono::minutes entry_expiry;
+
+public:
+ using key_type = cache_key_type;
+ using value_type = checked_weak_ptr;
+ using entry_is_too_big = typename cache_type::entry_is_too_big;
+ using iterator = typename cache_type::iterator;
+
+private:
+ cache_type _cache;
+ logging::logger& _logger;
+
+public:
+ // Choose the memory budget such that would allow us ~4K entries when a shard gets 1GB of RAM
+ authorized_prepared_statements_cache(std::chrono::milliseconds entry_refresh, logging::logger& logger)
+ : _cache(memory::stats().total_memory() / 2560, entry_expiry, entry_refresh, logger, [this] (const key_type& k) {
+ _cache.remove(k);
+ return make_ready_future();
+ })
+ , _logger(logger)
+ {}
+
+ future<> insert(auth::authenticated_user user, cql3::prepared_cache_key_type prep_cache_key, value_type v) noexcept {
+ return _cache.get_ptr(key_type(std::move(user), std::move(prep_cache_key)), [v = std::move(v)] (const cache_key_type&) mutable {
+ return make_ready_future(std::move(v));
+ }).discard_result();
+ }
+
+ iterator find(const auth::authenticated_user& user, const cql3::prepared_cache_key_type& prep_cache_key) {
+ struct key_view {
+ const auth::authenticated_user& user_ref;
+ const cql3::prepared_cache_key_type& prep_cache_key_ref;
+ };
+
+ struct hasher {
+ size_t operator()(const key_view& kv) {
+ return cql3::authorized_prepared_statements_cache_key::hash(kv.user_ref, kv.prep_cache_key_ref.key());
+ }
+ };
+
+ struct equal {
+ bool operator()(const key_type& k1, const key_view& k2) {
+ return k1.key().first == k2.user_ref && k1.key().second == k2.prep_cache_key_ref.key();
+ }
+
+ bool operator()(const key_view& k2, const key_type& k1) {
+ return operator()(k1, k2);
+ }
+ };
+
+ return _cache.find(key_view{user, prep_cache_key}, hasher(), equal());
+ }
+
+ iterator end() {
+ return _cache.end();
+ }
+
+ void remove(const auth::authenticated_user& user, const cql3::prepared_cache_key_type& prep_cache_key) {
+ iterator it = find(user, prep_cache_key);
+ _cache.remove(it);
+ }
+
+ size_t size() const {
+ return _cache.size();
+ }
+
+ size_t memory_footprint() const {
+ return _cache.memory_footprint();
+ }
+
+ future<> stop() {
+ return _cache.stop();
+ }
+};
+
+}
+
+namespace std {
+template <>
+struct hash final {
+ size_t operator()(const cql3::authorized_prepared_statements_cache_key& k) const {
+ return cql3::authorized_prepared_statements_cache_key::hash(k.key().first, k.key().second);
+ }
+};
+
+inline std::ostream& operator<<(std::ostream& out, const cql3::authorized_prepared_statements_cache_key& k) {
+ return out << "{ " << k.key().first << ", " << k.key().second << " }";
+}
+}
\ No newline at end of file
diff --git a/cql3/prepared_statements_cache.hh b/cql3/prepared_statements_cache.hh
index c345fe520e..338186aff0 100644
--- a/cql3/prepared_statements_cache.hh
+++ b/cql3/prepared_statements_cache.hh
@@ -68,6 +68,14 @@ public:
static thrift_prepared_id_type thrift_id(const prepared_cache_key_type& key) {
return key.key().second;
}
+
+ bool operator==(const prepared_cache_key_type& other) const {
+ return _key == other._key;
+ }
+
+ bool operator!=(const prepared_cache_key_type& other) const {
+ return !(*this == other);
+ }
};
class prepared_statements_cache {
@@ -155,6 +163,10 @@ public:
size_t memory_footprint() const {
return _cache.memory_footprint();
}
+
+ future<> stop() {
+ return _cache.stop();
+ }
};
}
@@ -168,4 +180,11 @@ inline std::ostream& operator<<(std::ostream& os, const cql3::prepared_cache_key
os << p.key();
return os;
}
+
+template<>
+struct hash final {
+ size_t operator()(const cql3::prepared_cache_key_type& k) const {
+ return utils::tuple_hash()(k.key());
+ }
+};
}
diff --git a/cql3/query_processor.cc b/cql3/query_processor.cc
index 693349653f..a6c657b7ac 100644
--- a/cql3/query_processor.cc
+++ b/cql3/query_processor.cc
@@ -58,12 +58,14 @@ using namespace cql_transport::messages;
logging::logger log("query_processor");
logging::logger prep_cache_log("prepared_statements_cache");
+logging::logger authorized_prepared_statements_cache_log("authorized_prepared_statements_cache");
distributed _the_query_processor;
const sstring query_processor::CQL_VERSION = "3.3.1";
const std::chrono::minutes prepared_statements_cache::entry_expiry = std::chrono::minutes(60);
+const std::chrono::minutes authorized_prepared_statements_cache::entry_expiry = std::chrono::minutes(60);
class query_processor::internal_state {
service::query_state _qs;
@@ -96,7 +98,8 @@ query_processor::query_processor(service::storage_proxy& proxy, distributed query_processor::stop() {
service::get_local_migration_manager().unregister_listener(_migration_subscriber.get());
- return make_ready_future<>();
+ return _authorized_prepared_cache.stop().finally([this] { return _prepared_cache.stop(); });
}
future<::shared_ptr>
@@ -230,33 +249,60 @@ query_processor::process(const sstring_view& query_string, service::query_state&
metrics.regularStatementsExecuted.inc();
#endif
tracing::trace(query_state.get_trace_state(), "Processing a statement");
- return process_statement(std::move(cql_statement), query_state, options);
+ return process_statement_unprepared(std::move(cql_statement), query_state, options);
}
future<::shared_ptr>
-query_processor::process_statement(
+query_processor::process_statement_unprepared(
::shared_ptr statement,
service::query_state& query_state,
const query_options& options) {
- return statement->check_access(query_state.get_client_state()).then([this, statement, &query_state, &options]() {
- auto& client_state = query_state.get_client_state();
+ return statement->check_access(query_state.get_client_state()).then([this, statement, &query_state, &options] () mutable {
+ return process_authorized_statement(std::move(statement), query_state, options);
+ });
+}
- statement->validate(_proxy, client_state);
+future<::shared_ptr>
+query_processor::process_statement_prepared(
+ statements::prepared_statement::checked_weak_ptr prepared,
+ cql3::prepared_cache_key_type cache_key,
+ service::query_state& query_state,
+ const query_options& options,
+ bool needs_authorization) {
- auto fut = make_ready_future<::shared_ptr>();
- if (client_state.is_internal()) {
- fut = statement->execute_internal(_proxy, query_state, options);
- } else {
- fut = statement->execute(_proxy, query_state, options);
- }
-
- return fut.then([statement] (auto msg) {
- if (msg) {
- return make_ready_future<::shared_ptr>(std::move(msg));
- }
- return make_ready_future<::shared_ptr>(
- ::make_shared());
+ ::shared_ptr statement = prepared->statement;
+ future<> fut = make_ready_future<>();
+ if (needs_authorization) {
+ fut = statement->check_access(query_state.get_client_state()).then([this, &query_state, prepared = std::move(prepared), cache_key = std::move(cache_key)] () mutable {
+ return _authorized_prepared_cache.insert(*query_state.get_client_state().user(), std::move(cache_key), std::move(prepared)).handle_exception([this] (auto eptr) {
+ log.error("failed to cache the entry", eptr);
+ });
});
+ }
+
+ return fut.then([this, statement = std::move(statement), &query_state, &options] () mutable {
+ return process_authorized_statement(std::move(statement), query_state, options);
+ });
+}
+
+future<::shared_ptr>
+query_processor::process_authorized_statement(const ::shared_ptr statement, service::query_state& query_state, const query_options& options) {
+ auto& client_state = query_state.get_client_state();
+
+ statement->validate(_proxy, client_state);
+
+ auto fut = make_ready_future<::shared_ptr>();
+ if (client_state.is_internal()) {
+ fut = statement->execute_internal(_proxy, query_state, options);
+ } else {
+ fut = statement->execute(_proxy, query_state, options);
+ }
+
+ return fut.then([statement] (auto msg) {
+ if (msg) {
+ return make_ready_future<::shared_ptr>(std::move(msg));
+ }
+ return make_ready_future<::shared_ptr>(::make_shared());
});
}
@@ -558,11 +604,18 @@ future<::shared_ptr>
query_processor::process_batch(
::shared_ptr batch,
service::query_state& query_state,
- query_options& options) {
- return batch->check_access(query_state.get_client_state()).then([this, &query_state, &options, batch] {
- batch->validate();
- batch->validate(_proxy, query_state.get_client_state());
- return batch->execute(_proxy, query_state, options);
+ query_options& options,
+ std::unordered_map pending_authorization_entries) {
+ return batch->check_access(query_state.get_client_state()).then([this, &query_state, &options, batch, pending_authorization_entries = std::move(pending_authorization_entries)] () mutable {
+ return parallel_for_each(pending_authorization_entries, [this, &query_state] (auto& e) {
+ return _authorized_prepared_cache.insert(*query_state.get_client_state().user(), e.first, std::move(e.second)).handle_exception([this] (auto eptr) {
+ log.error("failed to cache the entry", eptr);
+ });
+ }).then([this, &query_state, &options, batch] {
+ batch->validate();
+ batch->validate(_proxy, query_state.get_client_state());
+ return batch->execute(_proxy, query_state, options);
+ });
});
}
diff --git a/cql3/query_processor.hh b/cql3/query_processor.hh
index 1c6a2b1b18..2f671e0c96 100644
--- a/cql3/query_processor.hh
+++ b/cql3/query_processor.hh
@@ -49,6 +49,7 @@
#include
#include "cql3/prepared_statements_cache.hh"
+#include "cql3/authorized_prepared_statements_cache.hh"
#include "cql3/query_options.hh"
#include "cql3/statements/prepared_statement.hh"
#include "cql3/statements/raw/parsed_statement.hh"
@@ -117,6 +118,7 @@ private:
std::unique_ptr _internal_state;
prepared_statements_cache _prepared_cache;
+ authorized_prepared_statements_cache _authorized_prepared_cache;
// A map for prepared statements used internally (which we don't want to mix with user statement, in particular we
// don't bother with expiration on those.
@@ -151,6 +153,21 @@ public:
return _cql_stats;
}
+ statements::prepared_statement::checked_weak_ptr get_prepared(const auth::authenticated_user* user_ptr, const prepared_cache_key_type& key) {
+ if (user_ptr) {
+ auto it = _authorized_prepared_cache.find(*user_ptr, key);
+ if (it != _authorized_prepared_cache.end()) {
+ try {
+ return it->get()->checked_weak_from_this();
+ } catch (seastar::checked_ptr_is_null_exception&) {
+ // If the prepared statement got invalidated - remove the corresponding authorized_prepared_statements_cache entry as well.
+ _authorized_prepared_cache.remove(*user_ptr, key);
+ }
+ }
+ }
+ return statements::prepared_statement::checked_weak_ptr();
+ }
+
statements::prepared_statement::checked_weak_ptr get_prepared(const prepared_cache_key_type& key) {
auto it = _prepared_cache.find(key);
if (it == _prepared_cache.end()) {
@@ -160,11 +177,19 @@ public:
}
future<::shared_ptr>
- process_statement(
+ process_statement_unprepared(
::shared_ptr statement,
service::query_state& query_state,
const query_options& options);
+ future<::shared_ptr>
+ process_statement_prepared(
+ statements::prepared_statement::checked_weak_ptr statement,
+ cql3::prepared_cache_key_type cache_key,
+ service::query_state& query_state,
+ const query_options& options,
+ bool needs_authorization);
+
future<::shared_ptr>
process(
const std::experimental::string_view& query_string,
@@ -242,7 +267,11 @@ public:
future<> stop();
future<::shared_ptr>
- process_batch(::shared_ptr, service::query_state& query_state, query_options& options);
+ process_batch(
+ ::shared_ptr,
+ service::query_state& query_state,
+ query_options& options,
+ std::unordered_map pending_authorization_entries);
std::unique_ptr get_statement(
const std::experimental::string_view& query,
@@ -257,6 +286,9 @@ private:
db::consistency_level = db::consistency_level::ONE,
int32_t page_size = -1);
+ future<::shared_ptr>
+ process_authorized_statement(const ::shared_ptr statement, service::query_state& query_state, const query_options& options);
+
/*!
* \brief created a state object for paging
*
diff --git a/cql3/statements/batch_statement.cc b/cql3/statements/batch_statement.cc
index 96d9db1f13..8916781b95 100644
--- a/cql3/statements/batch_statement.cc
+++ b/cql3/statements/batch_statement.cc
@@ -75,19 +75,19 @@ timeout_for_type(batch_statement::type t) {
}
batch_statement::batch_statement(int bound_terms, type type_,
- std::vector> statements,
+ std::vector statements,
std::unique_ptr attrs,
cql_stats& stats)
: cql_statement_no_metadata(timeout_for_type(type_))
, _bound_terms(bound_terms), _type(type_), _statements(std::move(statements))
, _attrs(std::move(attrs))
- , _has_conditions(boost::algorithm::any_of(_statements, std::mem_fn(&modification_statement::has_conditions)))
+ , _has_conditions(boost::algorithm::any_of(_statements, [] (auto&& s) { return s.statement->has_conditions(); }))
, _stats(stats)
{
}
batch_statement::batch_statement(type type_,
- std::vector> statements,
+ std::vector statements,
std::unique_ptr attrs,
cql_stats& stats)
: batch_statement(-1, type_, std::move(statements), std::move(attrs), stats)
@@ -97,7 +97,7 @@ batch_statement::batch_statement(type type_,
bool batch_statement::uses_function(const sstring& ks_name, const sstring& function_name) const
{
return _attrs->uses_function(ks_name, function_name)
- || boost::algorithm::any_of(_statements, [&] (auto&& s) { return s->uses_function(ks_name, function_name); });
+ || boost::algorithm::any_of(_statements, [&] (auto&& s) { return s.statement->uses_function(ks_name, function_name); });
}
bool batch_statement::depends_on_keyspace(const sstring& ks_name) const
@@ -118,7 +118,11 @@ uint32_t batch_statement::get_bound_terms()
future<> batch_statement::check_access(const service::client_state& state)
{
return parallel_for_each(_statements.begin(), _statements.end(), [&state](auto&& s) {
- return s->check_access(state);
+ if (s.needs_authorization) {
+ return s.statement->check_access(state);
+ } else {
+ return make_ready_future<>();
+ }
});
}
@@ -138,12 +142,12 @@ void batch_statement::validate()
}
}
- bool has_counters = boost::algorithm::any_of(_statements, std::mem_fn(&modification_statement::is_counter));
- bool has_non_counters = !boost::algorithm::all_of(_statements, std::mem_fn(&modification_statement::is_counter));
+ bool has_counters = boost::algorithm::any_of(_statements, [] (auto&& s) { return s.statement->is_counter(); });
+ bool has_non_counters = !boost::algorithm::all_of(_statements, [] (auto&& s) { return s.statement->is_counter(); });
if (timestamp_set && has_counters) {
throw exceptions::invalid_request_exception("Cannot provide custom timestamp for a BATCH containing counters");
}
- if (timestamp_set && boost::algorithm::any_of(_statements, std::mem_fn(&modification_statement::is_timestamp_set))) {
+ if (timestamp_set && boost::algorithm::any_of(_statements, [] (auto&& s) { return s.statement->is_timestamp_set(); })) {
throw exceptions::invalid_request_exception("Timestamp must be set either on BATCH or individual statements");
}
if (_type == type::COUNTER && has_non_counters) {
@@ -159,30 +163,30 @@ void batch_statement::validate()
if (_has_conditions
&& !_statements.empty()
&& (boost::distance(_statements
- | boost::adaptors::transformed(std::mem_fn(&modification_statement::keyspace))
+ | boost::adaptors::transformed([] (auto&& s) { return s.statement->keyspace(); })
| boost::adaptors::uniqued) != 1
|| (boost::distance(_statements
- | boost::adaptors::transformed(std::mem_fn(&modification_statement::column_family))
+ | boost::adaptors::transformed([] (auto&& s) { return s.statement->column_family(); })
| boost::adaptors::uniqued) != 1))) {
throw exceptions::invalid_request_exception("Batch with conditions cannot span multiple tables");
}
std::experimental::optional raw_counter;
for (auto& s : _statements) {
- if (raw_counter && s->is_raw_counter_shard_write() != *raw_counter) {
+ if (raw_counter && s.statement->is_raw_counter_shard_write() != *raw_counter) {
throw exceptions::invalid_request_exception("Cannot mix raw and regular counter statements in batch");
}
- raw_counter = s->is_raw_counter_shard_write();
+ raw_counter = s.statement->is_raw_counter_shard_write();
}
}
void batch_statement::validate(service::storage_proxy& proxy, const service::client_state& state)
{
for (auto&& s : _statements) {
- s->validate(proxy, state);
+ s.statement->validate(proxy, state);
}
}
-const std::vector>& batch_statement::get_statements()
+const std::vector& batch_statement::get_statements()
{
return _statements;
}
@@ -196,7 +200,7 @@ future> batch_statement::get_mutations(service::storage_pr
return do_for_each(boost::make_counting_iterator(0),
boost::make_counting_iterator(_statements.size()),
[this, &storage, &options, now, local, &result, trace_state] (size_t i) {
- auto&& statement = _statements[i];
+ auto&& statement = _statements[i].statement;
statement->inc_cql_stats();
auto&& statement_options = options.for_statement(i);
auto timestamp = _attrs->get_timestamp(now, statement_options);
@@ -426,7 +430,9 @@ batch_statement::prepare(database& db, cql_stats& stats) {
stdx::optional first_cf;
bool have_multiple_cfs = false;
- std::vector> statements;
+ std::vector statements;
+ statements.reserve(_parsed_statements.size());
+
for (auto&& parsed : _parsed_statements) {
if (!first_ks) {
first_ks = parsed->keyspace();
@@ -434,7 +440,7 @@ batch_statement::prepare(database& db, cql_stats& stats) {
} else {
have_multiple_cfs = first_ks.value() != parsed->keyspace() || first_cf.value() != parsed->column_family();
}
- statements.push_back(parsed->prepare(db, bound_names, stats));
+ statements.emplace_back(parsed->prepare(db, bound_names, stats));
}
auto&& prep_attrs = _attrs->prepare(db, "[batch]", "[batch]");
@@ -445,7 +451,7 @@ batch_statement::prepare(database& db, cql_stats& stats) {
std::vector partition_key_bind_indices;
if (!have_multiple_cfs && batch_statement_.get_statements().size() > 0) {
- partition_key_bind_indices = bound_names->get_partition_key_bind_indexes(batch_statement_.get_statements()[0]->s);
+ partition_key_bind_indices = bound_names->get_partition_key_bind_indexes(batch_statement_.get_statements()[0].statement->s);
}
return std::make_unique(make_shared(std::move(batch_statement_)),
bound_names->get_specifications(),
diff --git a/cql3/statements/batch_statement.hh b/cql3/statements/batch_statement.hh
index cc0d4595b9..f31d24e79f 100644
--- a/cql3/statements/batch_statement.hh
+++ b/cql3/statements/batch_statement.hh
@@ -66,10 +66,24 @@ class batch_statement : public cql_statement_no_metadata {
static logging::logger _logger;
public:
using type = raw::batch_statement::type;
+
+ struct single_statement {
+ shared_ptr statement;
+ bool needs_authorization = true;
+
+ public:
+ single_statement(shared_ptr s)
+ : statement(std::move(s))
+ {}
+ single_statement(shared_ptr s, bool na)
+ : statement(std::move(s))
+ , needs_authorization(na)
+ {}
+ };
private:
int _bound_terms;
type _type;
- std::vector> _statements;
+ std::vector _statements;
std::unique_ptr _attrs;
bool _has_conditions;
cql_stats& _stats;
@@ -83,12 +97,12 @@ public:
* @param attrs additional attributes for statement (CL, timestamp, timeToLive)
*/
batch_statement(int bound_terms, type type_,
- std::vector> statements,
+ std::vector statements,
std::unique_ptr attrs,
cql_stats& stats);
batch_statement(type type_,
- std::vector> statements,
+ std::vector statements,
std::unique_ptr attrs,
cql_stats& stats);
@@ -109,7 +123,7 @@ public:
// or in QueryProcessor.processBatch() - for native protocol batches.
virtual void validate(service::storage_proxy& proxy, const service::client_state& state) override;
- const std::vector>& get_statements();
+ const std::vector& get_statements();
private:
future> get_mutations(service::storage_proxy& storage, const query_options& options, bool local, api::timestamp_type now, tracing::trace_state_ptr trace_state);
diff --git a/tests/cql_test_env.cc b/tests/cql_test_env.cc
index c3080c2dc3..59743618da 100644
--- a/tests/cql_test_env.cc
+++ b/tests/cql_test_env.cc
@@ -166,7 +166,7 @@ public:
options->prepare(prepared->bound_names);
auto qs = make_query_state();
- return local_qp().process_statement(stmt, *qs, *options)
+ return local_qp().process_statement_prepared(std::move(prepared), std::move(id), *qs, *options, true)
.finally([options, qs, this] {
_core_local.local().client_state.merge(qs->get_client_state());
});
diff --git a/tests/loading_cache_test.cc b/tests/loading_cache_test.cc
index 4e320eb8f5..51993f3046 100644
--- a/tests/loading_cache_test.cc
+++ b/tests/loading_cache_test.cc
@@ -229,6 +229,41 @@ SEASTAR_TEST_CASE(test_loading_cache_loading_same_key) {
});
}
+SEASTAR_THREAD_TEST_CASE(test_loading_cache_removing_key) {
+ using namespace std::chrono;
+ load_count = 0;
+ utils::loading_cache loading_cache(num_loaders, 100s, test_logger);
+ auto stop_cache_reload = seastar::defer([&loading_cache] { loading_cache.stop().get(); });
+
+ prepare().get();
+
+ loading_cache.get_ptr(0, loader).discard_result().get();
+ BOOST_REQUIRE_EQUAL(load_count, 1);
+ BOOST_REQUIRE(loading_cache.find(0) != loading_cache.end());
+
+ loading_cache.remove(0);
+ BOOST_REQUIRE(loading_cache.find(0) == loading_cache.end());
+}
+
+SEASTAR_THREAD_TEST_CASE(test_loading_cache_removing_iterator) {
+ using namespace std::chrono;
+ load_count = 0;
+ utils::loading_cache loading_cache(num_loaders, 100s, test_logger);
+ auto stop_cache_reload = seastar::defer([&loading_cache] { loading_cache.stop().get(); });
+
+ prepare().get();
+
+ loading_cache.get_ptr(0, loader).discard_result().get();
+ BOOST_REQUIRE_EQUAL(load_count, 1);
+
+ auto it = loading_cache.find(0);
+
+ BOOST_REQUIRE(it != loading_cache.end());
+
+ loading_cache.remove(it);
+ BOOST_REQUIRE(loading_cache.find(0) == loading_cache.end());
+}
+
SEASTAR_TEST_CASE(test_loading_cache_loading_different_keys) {
return seastar::async([] {
using namespace std::chrono;
diff --git a/thrift/handler.cc b/thrift/handler.cc
index 815a449970..44f0c72811 100644
--- a/thrift/handler.cc
+++ b/thrift/handler.cc
@@ -999,9 +999,17 @@ public:
void execute_prepared_cql3_query(tcxx::function cob, tcxx::function exn_cob, const int32_t itemId, const std::vector & values, const ConsistencyLevel::type consistency) {
with_exn_cob(std::move(exn_cob), [&] {
- auto prepared = _query_processor.local().get_prepared(cql3::prepared_cache_key_type(itemId));
+ cql3::prepared_cache_key_type cache_key(itemId);
+ bool needs_authorization = false;
+
+ auto prepared = _query_processor.local().get_prepared(_query_state.get_client_state().user().get(), cache_key);
if (!prepared) {
- throw make_exception("Prepared query with id %d not found", itemId);
+ needs_authorization = true;
+
+ prepared = _query_processor.local().get_prepared(cache_key);
+ if (!prepared) {
+ throw make_exception("Prepared query with id %d not found", itemId);
+ }
}
auto stmt = prepared->statement;
if (stmt->get_bound_terms() != values.size()) {
@@ -1013,7 +1021,7 @@ public:
});
auto opts = std::make_unique(cl_from_thrift(consistency), _timeout_config, stdx::nullopt, std::move(bytes_values),
false, cql3::query_options::specific_options::DEFAULT, cql_serialization_format::latest());
- auto f = _query_processor.local().process_statement(stmt, _query_state, *opts);
+ auto f = _query_processor.local().process_statement_prepared(std::move(prepared), std::move(cache_key), _query_state, *opts, needs_authorization);
return f.then([cob = std::move(cob), opts = std::move(opts)](auto&& ret) {
cql3_result_visitor visitor;
ret->accept(visitor);
diff --git a/tracing/trace_keyspace_helper.cc b/tracing/trace_keyspace_helper.cc
index 63c1d42de9..1aa5236a5f 100644
--- a/tracing/trace_keyspace_helper.cc
+++ b/tracing/trace_keyspace_helper.cc
@@ -357,7 +357,7 @@ future<> trace_keyspace_helper::apply_events_mutation(lw_shared_ptrsession_id, events_records.size(), records->parent_id, records->my_span_id);
- std::vector> modifications(events_records.size(), _events.insert_stmt());
+ std::vector modifications(events_records.size(), cql3::statements::batch_statement::single_statement(_events.insert_stmt(), false));
std::vector> values;
auto& qp = cql3::get_local_query_processor();
diff --git a/transport/server.cc b/transport/server.cc
index e21bf63bf5..c4233d0644 100644
--- a/transport/server.cc
+++ b/transport/server.cc
@@ -910,7 +910,16 @@ future cql_server::connection::process_execute(uint16_t stream, b
{
cql3::prepared_cache_key_type cache_key(read_short_bytes(buf));
auto& id = cql3::prepared_cache_key_type::cql_id(cache_key);
- auto prepared = _server._query_processor.local().get_prepared(cache_key);
+ bool needs_authorization = false;
+
+ // First, try to lookup in the cache of already authorized statements. If the corresponding entry is not found there
+ // look for the prepared statement and then authorize it.
+ auto prepared = _server._query_processor.local().get_prepared(client_state.user().get(), cache_key);
+ if (!prepared) {
+ needs_authorization = true;
+ prepared = _server._query_processor.local().get_prepared(cache_key);
+ }
+
if (!prepared) {
throw exceptions::prepared_query_not_found_exception(id);
}
@@ -945,7 +954,7 @@ future cql_server::connection::process_execute(uint16_t stream, b
throw exceptions::invalid_request_exception("Invalid amount of bind variables");
}
tracing::trace(query_state.get_trace_state(), "Processing a statement");
- return _server._query_processor.local().process_statement(stmt, query_state, options).then([this, stream, buf = std::move(buf), &query_state, skip_metadata] (auto msg) {
+ return _server._query_processor.local().process_statement_prepared(std::move(prepared), std::move(cache_key), query_state, options, needs_authorization).then([this, stream, buf = std::move(buf), &query_state, skip_metadata] (auto msg) {
tracing::trace(query_state.get_trace_state(), "Done processing - preparing a result");
return this->make_result(stream, msg, query_state.get_trace_state(), skip_metadata);
}).then([&query_state, q_state = std::move(q_state), this] (auto&& response) {
@@ -964,8 +973,9 @@ cql_server::connection::process_batch(uint16_t stream, bytes_view buf, service::
const auto type = read_byte(buf);
const unsigned n = read_short(buf);
- std::vector> modifications;
+ std::vector modifications;
std::vector> values;
+ std::unordered_map pending_authorization_entries;
modifications.reserve(n);
values.reserve(n);
@@ -977,6 +987,7 @@ cql_server::connection::process_batch(uint16_t stream, bytes_view buf, service::
std::unique_ptr stmt_ptr;
cql3::statements::prepared_statement::checked_weak_ptr ps;
+ bool needs_authorization(kind == 0);
switch (kind) {
case 0: {
@@ -988,10 +999,19 @@ cql_server::connection::process_batch(uint16_t stream, bytes_view buf, service::
case 1: {
cql3::prepared_cache_key_type cache_key(read_short_bytes(buf));
auto& id = cql3::prepared_cache_key_type::cql_id(cache_key);
- ps = _server._query_processor.local().get_prepared(cache_key);
+
+ // First, try to lookup in the cache of already authorized statements. If the corresponding entry is not found there
+ // look for the prepared statement and then authorize it.
+ ps = _server._query_processor.local().get_prepared(client_state.user().get(), cache_key);
if (!ps) {
- throw exceptions::prepared_query_not_found_exception(id);
+ ps = _server._query_processor.local().get_prepared(cache_key);
+ if (!ps) {
+ throw exceptions::prepared_query_not_found_exception(id);
+ }
+ // authorize a particular prepared statement only once
+ needs_authorization = pending_authorization_entries.emplace(std::move(cache_key), ps->checked_weak_from_this()).second;
}
+
break;
}
default:
@@ -1007,7 +1027,7 @@ cql_server::connection::process_batch(uint16_t stream, bytes_view buf, service::
::shared_ptr modif_statement_ptr = static_pointer_cast(ps->statement);
tracing::add_table_name(client_state.get_trace_state(), modif_statement_ptr->keyspace(), modif_statement_ptr->column_family());
- modifications.emplace_back(std::move(modif_statement_ptr));
+ modifications.emplace_back(std::move(modif_statement_ptr), needs_authorization);
std::vector tmp;
read_value_view_list(buf, tmp);
@@ -1031,7 +1051,7 @@ cql_server::connection::process_batch(uint16_t stream, bytes_view buf, service::
tracing::trace(client_state.get_trace_state(), "Creating a batch statement");
auto batch = ::make_shared(cql3::statements::batch_statement::type(type), std::move(modifications), cql3::attributes::none(), _server._query_processor.local().get_cql_stats());
- return _server._query_processor.local().process_batch(batch, query_state, options).then([this, stream, batch, &query_state] (auto msg) {
+ return _server._query_processor.local().process_batch(batch, query_state, options, std::move(pending_authorization_entries)).then([this, stream, batch, &query_state] (auto msg) {
return this->make_result(stream, msg, query_state.get_trace_state());
}).then([&query_state, q_state = std::move(q_state), this] (auto&& response) {
/* Keep q_state alive. */
diff --git a/utils/loading_cache.hh b/utils/loading_cache.hh
index d8a96ce2e9..8ae74de30a 100644
--- a/utils/loading_cache.hh
+++ b/utils/loading_cache.hh
@@ -25,6 +25,7 @@
#include
#include
#include
+#include
#include
#include
@@ -79,6 +80,10 @@ public:
value_type& value() noexcept { return _value; }
const value_type& value() const noexcept { return _value; }
+ static const timestamped_val& container_of(const value_type& value) {
+ return *bi::get_parent_from_member(&value, ×tamped_val::_value);
+ }
+
loading_cache_clock_type::time_point last_read() const noexcept {
return _last_read;
}
@@ -95,6 +100,10 @@ public:
return _lru_entry_ptr;
}
+ lru_entry* lru_entry_ptr() const noexcept {
+ return _lru_entry_ptr;
+ }
+
private:
void touch() noexcept {
assert(_lru_entry_ptr);
@@ -365,6 +374,11 @@ public:
return _timer_reads_gate.close().finally([this] { _timer.cancel(); });
}
+ template
+ iterator find(const KeyType& key, KeyHasher key_hasher_func, KeyEqual key_equal_func) noexcept {
+ return boost::make_transform_iterator(set_find(key, std::move(key_hasher_func), std::move(key_equal_func)), _value_extractor_fn);
+ };
+
iterator find(const Key& k) noexcept {
return boost::make_transform_iterator(set_find(k), _value_extractor_fn);
}
@@ -388,6 +402,24 @@ public:
});
}
+ void remove(const Key& k) {
+ auto it = set_find(k);
+ if (it == set_end()) {
+ return;
+ }
+
+ _lru_list.erase_and_dispose(_lru_list.iterator_to(*it->lru_entry_ptr()), [this] (ts_value_lru_entry* p) { loading_cache::destroy_ts_value(p); });
+ }
+
+ void remove(iterator it) {
+ if (it == end()) {
+ return;
+ }
+
+ const ts_value_type& val = ts_value_type::container_of(*it);
+ _lru_list.erase_and_dispose(_lru_list.iterator_to(*val.lru_entry_ptr()), [this] (ts_value_lru_entry* p) { loading_cache::destroy_ts_value(p); });
+ }
+
size_t size() const {
return _loading_values.size();
}
@@ -398,8 +430,7 @@ public:
}
private:
- set_iterator set_find(const Key& k) noexcept {
- set_iterator it = _loading_values.find(k);
+ set_iterator ready_entry_iterator(set_iterator it) {
set_iterator end_it = set_end();
if (it == end_it || !it->ready()) {
@@ -408,6 +439,17 @@ private:
return it;
}
+ template
+ set_iterator set_find(const KeyType& key, KeyHasher key_hasher_func, KeyEqual key_equal_func) noexcept {
+ return ready_entry_iterator(_loading_values.find(key, std::move(key_hasher_func), std::move(key_equal_func)));
+ }
+
+ // keep the default non-templated overloads to ease on the compiler for specifications
+ // that do not require the templated find().
+ set_iterator set_find(const Key& key) noexcept {
+ return ready_entry_iterator(_loading_values.find(key));
+ }
+
set_iterator set_end() noexcept {
return _loading_values.end();
}
diff --git a/utils/loading_shared_values.hh b/utils/loading_shared_values.hh
index bfa6a276ff..55cc7fcae7 100644
--- a/utils/loading_shared_values.hh
+++ b/utils/loading_shared_values.hh
@@ -107,16 +107,6 @@ private:
return bool(_val);
}
- struct key_eq {
- bool operator()(const key_type& k, const entry& c) const {
- return EqualPred()(k, c.key());
- }
-
- bool operator()(const entry& c, const key_type& k) const {
- return EqualPred()(c.key(), k);
- }
- };
-
entry(loading_shared_values& parent, key_type k)
: _parent(parent), _key(std::move(k)) {}
@@ -134,6 +124,17 @@ private:
}
};
+ template
+ struct key_eq {
+ bool operator()(const KeyType& k, const entry& c) const {
+ return KeyEqual()(k, c.key());
+ }
+
+ bool operator()(const entry& c, const KeyType& k) const {
+ return KeyEqual()(c.key(), k);
+ }
+ };
+
using set_type = bi::unordered_set, bi::compare_hash>;
using bi_set_bucket_traits = typename set_type::bucket_traits;
using set_iterator = typename set_type::iterator;
@@ -216,7 +217,7 @@ public:
future get_or_load(const key_type& key, Loader&& loader) noexcept {
static_assert(std::is_same, typename futurize>::type>::value, "Bad Loader signature");
try {
- auto i = _set.find(key, Hash(), typename entry::key_eq());
+ auto i = _set.find(key, Hash(), key_eq());
lw_shared_ptr e;
future<> f = make_ready_future<>();
if (i != _set.end()) {
@@ -280,12 +281,19 @@ public:
return boost::make_transform_iterator(_set.begin(), _value_extractor_fn);
}
- iterator find(const key_type& key) noexcept {
- set_iterator it = _set.find(key, Hash(), typename entry::key_eq());
+ template
+ iterator find(const KeyType& key, KeyHasher key_hasher_func, KeyEqual key_equal_func) noexcept {
+ set_iterator it = _set.find(key, std::move(key_hasher_func), key_eq());
if (it == _set.end() || !it->ready()) {
return end();
}
return boost::make_transform_iterator(it, _value_extractor_fn);
+ };
+
+ // keep the default non-templated overloads to ease on the compiler for specifications
+ // that do not require the templated find().
+ iterator find(const key_type& key) noexcept {
+ return find(key, Hash(), EqualPred());
}
private: