diff --git a/audit/audit.cc b/audit/audit.cc index e455fee910..d7b1a3a26d 100644 --- a/audit/audit.cc +++ b/audit/audit.cc @@ -209,15 +209,11 @@ future<> audit::stop_audit() { }); } -audit_info_ptr audit::create_audit_info(statement_category cat, const sstring& keyspace, const sstring& table) { +audit_info_ptr audit::create_audit_info(statement_category cat, const sstring& keyspace, const sstring& table, bool batch) { if (!audit_instance().local_is_initialized()) { return nullptr; } - return std::make_unique(cat, keyspace, table); -} - -audit_info_ptr audit::create_no_audit_info() { - return audit_info_ptr(); + return std::make_unique(cat, keyspace, table, batch); } future<> audit::start(const db::config& cfg) { @@ -267,18 +263,21 @@ future<> audit::log_login(const sstring& username, socket_address client_ip, boo } future<> inspect(shared_ptr statement, service::query_state& query_state, const cql3::query_options& options, bool error) { - cql3::statements::batch_statement* batch = dynamic_cast(statement.get()); - if (batch != nullptr) { + auto audit_info = statement->get_audit_info(); + if (!audit_info) { + return make_ready_future<>(); + } + if (audit_info->batch()) { + cql3::statements::batch_statement* batch = static_cast(statement.get()); return do_for_each(batch->statements().begin(), batch->statements().end(), [&query_state, &options, error] (auto&& m) { return inspect(m.statement, query_state, options, error); }); } else { - auto audit_info = statement->get_audit_info(); - if (bool(audit_info) && audit::local_audit_instance().should_log(audit_info)) { + if (audit::local_audit_instance().should_log(audit_info)) { return audit::local_audit_instance().log(audit_info, query_state, options, error); } + return make_ready_future<>(); } - return make_ready_future<>(); } future<> inspect_login(const sstring& username, socket_address client_ip, bool error) { diff --git a/audit/audit.hh b/audit/audit.hh index 51c2c13344..5e24bccfaf 100644 --- a/audit/audit.hh +++ b/audit/audit.hh @@ -75,11 +75,13 @@ class audit_info final { sstring _keyspace; sstring _table; sstring _query; + bool _batch; public: - audit_info(statement_category cat, sstring keyspace, sstring table) + audit_info(statement_category cat, sstring keyspace, sstring table, bool batch) : _category(cat) , _keyspace(std::move(keyspace)) , _table(std::move(table)) + , _batch(batch) { } void set_query_string(const std::string_view& query_string) { _query = sstring(query_string); @@ -89,6 +91,7 @@ public: const sstring& query() const { return _query; } sstring category_string() const; statement_category category() const { return _category; } + bool batch() const { return _batch; } }; using audit_info_ptr = std::unique_ptr; @@ -126,8 +129,7 @@ public: } static future<> start_audit(const db::config& cfg, sharded& stm, sharded& qp, sharded& mm); static future<> stop_audit(); - static audit_info_ptr create_audit_info(statement_category cat, const sstring& keyspace, const sstring& table); - static audit_info_ptr create_no_audit_info(); + static audit_info_ptr create_audit_info(statement_category cat, const sstring& keyspace, const sstring& table, bool batch = false); audit(locator::shared_token_metadata& stm, cql3::query_processor& qp, service::migration_manager& mm, diff --git a/cql3/statements/raw/batch_statement.hh b/cql3/statements/raw/batch_statement.hh index cf2f0b83a0..1091d15ccd 100644 --- a/cql3/statements/raw/batch_statement.hh +++ b/cql3/statements/raw/batch_statement.hh @@ -50,8 +50,8 @@ public: protected: virtual audit::statement_category category() const override; virtual audit::audit_info_ptr audit_info() const override { - // We don't audit batch statements. Instead we audit statements that are inside the batch. - return audit::audit::create_no_audit_info(); + constexpr bool batch = true; + return audit::audit::create_audit_info(category(), sstring(), sstring(), batch); } };