diff --git a/cql3/query_options.cc b/cql3/query_options.cc index 5fdb42766f..0048a4fc4d 100644 --- a/cql3/query_options.cc +++ b/cql3/query_options.cc @@ -71,6 +71,17 @@ query_options::query_options(db::consistency_level consistency, { } +query_options::query_options(query_options&& o, std::vector> value_views) + : query_options(std::move(o)) +{ + std::vector tmp; + tmp.reserve(value_views.size()); + std::transform(value_views.begin(), value_views.end(), std::back_inserter(tmp), [this](auto& vals) { + return query_options(_consistency, {}, vals, _skip_metadata, _options, _protocol_version, _serialization_format); + }); + _batch_options = std::move(tmp); +} + query_options::query_options(std::vector values) : query_options( db::consistency_level::ONE, diff --git a/cql3/query_options.hh b/cql3/query_options.hh index 60977ffe6f..8061de75df 100644 --- a/cql3/query_options.hh +++ b/cql3/query_options.hh @@ -61,6 +61,9 @@ private: serialization_format _serialization_format; std::experimental::optional> _batch_options; public: + query_options(query_options&&) = default; + query_options(const query_options&) = delete; + explicit query_options(db::consistency_level consistency, std::experimental::optional> names, std::vector values, @@ -77,6 +80,16 @@ public: int32_t protocol_version, serialization_format sf); + explicit query_options(db::consistency_level consistency, + std::vector> value_views, + bool skip_metadata, + specific_options options, + int32_t protocol_version, + serialization_format sf); + + // Batch query_options constructor + explicit query_options(query_options&&, std::vector> value_views); + // It can't be const because of prepare() static thread_local query_options DEFAULT; diff --git a/cql3/query_processor.cc b/cql3/query_processor.cc index a3ddc97a2c..d8b0fe3709 100644 --- a/cql3/query_processor.cc +++ b/cql3/query_processor.cc @@ -25,6 +25,7 @@ #include "cql3/query_processor.hh" #include "cql3/CqlParser.hpp" #include "cql3/error_collector.hh" +#include "cql3/statements/batch_statement.hh" #include "transport/messages/result_message.hh" @@ -311,4 +312,14 @@ future<::shared_ptr> query_processor::execute_internal( }); } +future<::shared_ptr> +query_processor::process_batch(::shared_ptr batch, service::query_state& query_state, query_options& options) { + auto& client_state = query_state.get_client_state(); + batch->check_access(client_state); + batch->validate(); + batch->validate(_proxy, client_state); + return batch->execute(_proxy, query_state, options); +} + + } diff --git a/cql3/query_processor.hh b/cql3/query_processor.hh index 274515823e..4cc0a3e2bc 100644 --- a/cql3/query_processor.hh +++ b/cql3/query_processor.hh @@ -39,6 +39,10 @@ namespace cql3 { +namespace statements { +class batch_statement; +} + class query_processor { private: distributed& _proxy; @@ -425,19 +429,12 @@ private: metrics.preparedStatementsExecuted.inc(); return processStatement(statement, queryState, options); } - - public ResultMessage processBatch(BatchStatement batch, QueryState queryState, BatchQueryOptions options) - throws RequestExecutionException, RequestValidationException - { - ClientState clientState = queryState.getClientState(); - batch.checkAccess(clientState); - batch.validate(); - batch.validate(clientState); - return batch.execute(queryState, options); - } #endif public: + future<::shared_ptr> process_batch(::shared_ptr, + service::query_state& query_state, query_options& options); + ::shared_ptr get_statement(const std::experimental::string_view& query, const service::client_state& client_state); static ::shared_ptr parse_statement(const std::experimental::string_view& query); diff --git a/streaming/stream_session.cc b/streaming/stream_session.cc index 38c41732c8..a3f7c0431d 100644 --- a/streaming/stream_session.cc +++ b/streaming/stream_session.cc @@ -214,24 +214,24 @@ future<> stream_session::test(distributed& qp) { sslog.debug("================ STREAM_PLAN TEST =============="); auto cs = service::client_state::for_external_calls(); service::query_state qs(cs); - auto opts = make_shared(cql3::query_options::DEFAULT); - qp.local().process("CREATE KEYSPACE ks WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };", qs, *opts).get(); + auto& opts = cql3::query_options::DEFAULT; + qp.local().process("CREATE KEYSPACE ks WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };", qs, opts).get(); sslog.debug("CREATE KEYSPACE = KS DONE"); sleep(std::chrono::seconds(3)).get(); - qp.local().process("CREATE TABLE ks.tb ( key text PRIMARY KEY, C0 text, C1 text, C2 text, C3 blob, C4 text);", qs, *opts).get(); + qp.local().process("CREATE TABLE ks.tb ( key text PRIMARY KEY, C0 text, C1 text, C2 text, C3 blob, C4 text);", qs, opts).get(); sslog.debug("CREATE TABLE = TB DONE"); sleep(std::chrono::seconds(3)).get(); - qp.local().process("insert into ks.tb (key,c0) values ('1','1');", qs, *opts).get(); + qp.local().process("insert into ks.tb (key,c0) values ('1','1');", qs, opts).get(); sslog.debug("INSERT VALUE DONE: 1"); - qp.local().process("insert into ks.tb (key,c0) values ('2','2');", qs, *opts).get(); + qp.local().process("insert into ks.tb (key,c0) values ('2','2');", qs, opts).get(); sslog.debug("INSERT VALUE DONE: 2"); - qp.local().process("insert into ks.tb (key,c0) values ('3','3');", qs, *opts).get(); + qp.local().process("insert into ks.tb (key,c0) values ('3','3');", qs, opts).get(); sslog.debug("INSERT VALUE DONE: 3"); - qp.local().process("insert into ks.tb (key,c0) values ('4','4');", qs, *opts).get(); + qp.local().process("insert into ks.tb (key,c0) values ('4','4');", qs, opts).get(); sslog.debug("INSERT VALUE DONE: 4"); - qp.local().process("insert into ks.tb (key,c0) values ('5','5');", qs, *opts).get(); + qp.local().process("insert into ks.tb (key,c0) values ('5','5');", qs, opts).get(); sslog.debug("INSERT VALUE DONE: 5"); - qp.local().process("insert into ks.tb (key,c0) values ('6','6');", qs, *opts).get(); + qp.local().process("insert into ks.tb (key,c0) values ('6','6');", qs, opts).get(); sslog.debug("INSERT VALUE DONE: 6"); }).then([] { sleep(std::chrono::seconds(10)).then([] { diff --git a/transport/server.cc b/transport/server.cc index 39db160205..b5d9ad8f72 100644 --- a/transport/server.cc +++ b/transport/server.cc @@ -11,6 +11,7 @@ #include #include +#include "cql3/statements/batch_statement.hh" #include "service/migration_manager.hh" #include "service/storage_service.hh" #include "db/consistency_level.hh" @@ -533,8 +534,71 @@ future<> cql_server::connection::process_execute(uint16_t stream, temporary_buff future<> cql_server::connection::process_batch(uint16_t stream, temporary_buffer buf) { - assert(0); - return make_ready_future<>(); + if (_version == 1) { + throw exceptions::protocol_exception("BATCH messages are not support in version 1 of the protocol"); + } + + const auto type = read_byte(buf); + const unsigned n = read_unsigned_short(buf); + + std::vector> modifications; + std::vector> values; + + modifications.reserve(n); + values.reserve(n); + + for ([[gnu::unused]] auto i : boost::irange(0u, n)) { + const auto kind = read_byte(buf); + + ::shared_ptr ps; + + switch (kind) { + case 0: { + auto query = read_long_string_view(buf).to_string(); + ps = _server._query_processor.local().get_statement(query, + _client_state); + break; + } + case 1: { + auto id = read_short_bytes(buf); + ps = _server._query_processor.local().get_prepared(id); + if (!ps) { + throw exceptions::prepared_query_not_found_exception(id); + } + break; + } + default: + throw exceptions::protocol_exception( + "Invalid query kind in BATCH messages. Must be 0 or 1 but got " + + std::to_string(int(kind))); + } + + if (dynamic_cast(ps->statement.get()) == nullptr) { + throw exceptions::invalid_request_exception("Invalid statement in batch: only UPDATE, INSERT and DELETE statements are allowed."); + } + + modifications.emplace_back(static_pointer_cast(ps->statement)); + + std::vector tmp; + read_value_view_list(buf, tmp); + + auto stmt = ps->statement; + if (stmt->get_bound_terms() != tmp.size()) { + throw exceptions::invalid_request_exception(sprint("There were %d markers(?) in CQL but %d bound variables", + stmt->get_bound_terms(), tmp.size())); + } + values.emplace_back(std::move(tmp)); + } + + auto& q_state = get_query_state(stream); + auto& query_state = q_state.query_state; + q_state.options = std::make_unique(std::move(*read_options(buf)), std::move(values)); + auto& options = *q_state.options; + + auto batch = ::make_shared(-1, cql3::statements::batch_statement::type(type), std::move(modifications), cql3::attributes::none()); + return _server._query_processor.local().process_batch(batch, query_state, options).then([this, stream, batch] (auto msg) { + return this->write_result(stream, msg); + }); } future<> cql_server::connection::process_register(uint16_t stream, temporary_buffer buf)