diff --git a/cql3/Cql.g b/cql3/Cql.g index 0c1b96843f..994aba7d99 100644 --- a/cql3/Cql.g +++ b/cql3/Cql.g @@ -41,6 +41,7 @@ options { #include "cql3/statements/batch_statement.hh" #include "cql3/statements/ks_prop_defs.hh" #include "cql3/selection/raw_selector.hh" +#include "cql3/selection/selectable_with_field_selection.hh" #include "cql3/constants.hh" #include "cql3/operation_impl.hh" #include "cql3/error_listener.hh" @@ -52,6 +53,7 @@ options { #include "cql3/lists.hh" #include "cql3/type_cast.hh" #include "cql3/tuples.hh" +#include "cql3/user_types.hh" #include "cql3/functions/function_name.hh" #include "cql3/functions/function_call.hh" #include "core/sstring.hh" @@ -311,9 +313,7 @@ unaliasedSelector returns [shared_ptr s] | K_TTL '(' c=cident ')' { tmp = make_shared(c, false); } | f=functionName args=selectionFunctionArgs { tmp = ::make_shared(std::move(f), std::move(args)); } ) -#if 0 - ( '.' fi=cident { tmp = new Selectable.WithFieldSelection.Raw(tmp, fi); } )* -#endif + ( '.' fi=cident { tmp = make_shared(std::move(tmp), std::move(fi)); } )* { $s = tmp; } ; @@ -1033,15 +1033,12 @@ collectionLiteral returns [shared_ptr value] | '{' '}' { $value = make_shared(cql3::sets::literal({})); } ; -#if 0 - -usertypeLiteral returns [UserTypes.Literal ut] - @init{ Map m = new HashMap(); } - @after{ $ut = new UserTypes.Literal(m); } +usertypeLiteral returns [shared_ptr ut] + @init{ cql3::user_types::literal::elements_map_type m; } + @after{ $ut = ::make_shared(std::move(m)); } // We don't allow empty literals because that conflicts with sets/maps and is currently useless since we don't allow empty user types - : '{' k1=ident ':' v1=term { m.put(k1, v1); } ( ',' kn=ident ':' vn=term { m.put(kn, vn); } )* '}' + : '{' k1=ident ':' v1=term { m.emplace(std::move(*k1), std::move(v1)); } ( ',' kn=ident ':' vn=term { m.emplace(std::move(*kn), std::move(vn)); } )* '}' ; -#endif tupleLiteral returns [shared_ptr tt] @init{ std::vector> l; } @@ -1052,9 +1049,7 @@ tupleLiteral returns [shared_ptr tt] value returns [::shared_ptr value] : c=constant { $value = c; } | l=collectionLiteral { $value = l; } -#if 0 | u=usertypeLiteral { $value = u; } -#endif | t=tupleLiteral { $value = t; } | K_NULL { $value = cql3::constants::NULL_LITERAL; } | ':' id=ident { $value = new_bind_variables(id); } diff --git a/cql3/UserTypes.java b/cql3/UserTypes.java deleted file mode 100644 index 934344c19d..0000000000 --- a/cql3/UserTypes.java +++ /dev/null @@ -1,201 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.cassandra.cql3; - -import java.nio.ByteBuffer; -import java.util.*; - -import org.apache.cassandra.db.marshal.CollectionType; -import org.apache.cassandra.db.marshal.UserType; -import org.apache.cassandra.db.marshal.UTF8Type; -import org.apache.cassandra.exceptions.InvalidRequestException; -import org.apache.cassandra.transport.Server; - -/** - * Static helper methods and classes for user types. - */ -public abstract class UserTypes -{ - private UserTypes() {} - - public static ColumnSpecification fieldSpecOf(ColumnSpecification column, int field) - { - UserType ut = (UserType)column.type; - return new ColumnSpecification(column.ksName, - column.cfName, - new ColumnIdentifier(column.name + "." + UTF8Type.instance.compose(ut.fieldName(field)), true), - ut.fieldType(field)); - } - - public static class Literal implements Term.Raw - { - public final Map entries; - - public Literal(Map entries) - { - this.entries = entries; - } - - public Term prepare(String keyspace, ColumnSpecification receiver) throws InvalidRequestException - { - validateAssignableTo(keyspace, receiver); - - UserType ut = (UserType)receiver.type; - boolean allTerminal = true; - List values = new ArrayList<>(entries.size()); - int foundValues = 0; - for (int i = 0; i < ut.size(); i++) - { - ColumnIdentifier field = new ColumnIdentifier(ut.fieldName(i), UTF8Type.instance); - Term.Raw raw = entries.get(field); - if (raw == null) - raw = Constants.NULL_LITERAL; - else - ++foundValues; - Term value = raw.prepare(keyspace, fieldSpecOf(receiver, i)); - - if (value instanceof Term.NonTerminal) - allTerminal = false; - - values.add(value); - } - if (foundValues != entries.size()) - { - // We had some field that are not part of the type - for (ColumnIdentifier id : entries.keySet()) - if (!ut.fieldNames().contains(id.bytes)) - throw new InvalidRequestException(String.format("Unknown field '%s' in value of user defined type %s", id, ut.getNameAsString())); - } - - DelayedValue value = new DelayedValue(((UserType)receiver.type), values); - return allTerminal ? value.bind(QueryOptions.DEFAULT) : value; - } - - private void validateAssignableTo(String keyspace, ColumnSpecification receiver) throws InvalidRequestException - { - if (!(receiver.type instanceof UserType)) - throw new InvalidRequestException(String.format("Invalid user type literal for %s of type %s", receiver, receiver.type.asCQL3Type())); - - UserType ut = (UserType)receiver.type; - for (int i = 0; i < ut.size(); i++) - { - ColumnIdentifier field = new ColumnIdentifier(ut.fieldName(i), UTF8Type.instance); - Term.Raw value = entries.get(field); - if (value == null) - continue; - - ColumnSpecification fieldSpec = fieldSpecOf(receiver, i); - if (!value.testAssignment(keyspace, fieldSpec).isAssignable()) - throw new InvalidRequestException(String.format("Invalid user type literal for %s: field %s is not of type %s", receiver, field, fieldSpec.type.asCQL3Type())); - } - } - - public AssignmentTestable.TestResult testAssignment(String keyspace, ColumnSpecification receiver) - { - try - { - validateAssignableTo(keyspace, receiver); - return AssignmentTestable.TestResult.WEAKLY_ASSIGNABLE; - } - catch (InvalidRequestException e) - { - return AssignmentTestable.TestResult.NOT_ASSIGNABLE; - } - } - - @Override - public String toString() - { - StringBuilder sb = new StringBuilder(); - sb.append("{"); - Iterator> iter = entries.entrySet().iterator(); - while (iter.hasNext()) - { - Map.Entry entry = iter.next(); - sb.append(entry.getKey()).append(":").append(entry.getValue()); - if (iter.hasNext()) - sb.append(", "); - } - sb.append("}"); - return sb.toString(); - } - } - - // Same purpose than Lists.DelayedValue, except we do handle bind marker in that case - public static class DelayedValue extends Term.NonTerminal - { - private final UserType type; - private final List values; - - public DelayedValue(UserType type, List values) - { - this.type = type; - this.values = values; - } - - public boolean usesFunction(String ksName, String functionName) - { - if (values != null) - for (Term value : values) - if (value != null && value.usesFunction(ksName, functionName)) - return true; - return false; - } - - public boolean containsBindMarker() - { - for (Term t : values) - if (t.containsBindMarker()) - return true; - return false; - } - - public void collectMarkerSpecification(VariableSpecifications boundNames) - { - for (int i = 0; i < type.size(); i++) - values.get(i).collectMarkerSpecification(boundNames); - } - - private ByteBuffer[] bindInternal(QueryOptions options) throws InvalidRequestException - { - int version = options.getProtocolVersion(); - - ByteBuffer[] buffers = new ByteBuffer[values.size()]; - for (int i = 0; i < type.size(); i++) - { - buffers[i] = values.get(i).bindAndGet(options); - // Inside UDT values, we must force the serialization of collections to v3 whatever protocol - // version is in use since we're going to store directly that serialized value. - if (version < Server.VERSION_3 && type.fieldType(i).isCollection() && buffers[i] != null) - buffers[i] = ((CollectionType)type.fieldType(i)).getSerializer().reserializeToV3(buffers[i]); - } - return buffers; - } - - public Constants.Value bind(QueryOptions options) throws InvalidRequestException - { - return new Constants.Value(bindAndGet(options)); - } - - @Override - public ByteBuffer bindAndGet(QueryOptions options) throws InvalidRequestException - { - return UserType.buildValue(bindInternal(options)); - } - } -} diff --git a/cql3/abstract_marker.cc b/cql3/abstract_marker.cc index af365240c8..272f9c4f97 100644 --- a/cql3/abstract_marker.cc +++ b/cql3/abstract_marker.cc @@ -31,7 +31,7 @@ namespace cql3 { -::shared_ptr abstract_marker::raw::prepare(const sstring& keyspace, ::shared_ptr receiver) +::shared_ptr abstract_marker::raw::prepare(database& db, const sstring& keyspace, ::shared_ptr receiver) { auto receiver_type = ::dynamic_pointer_cast(receiver->type); if (receiver_type == nullptr) { @@ -47,7 +47,7 @@ namespace cql3 { assert(0); } -::shared_ptr abstract_marker::in_raw::prepare(const sstring& keyspace, ::shared_ptr receiver) { +::shared_ptr abstract_marker::in_raw::prepare(database& db, const sstring& keyspace, ::shared_ptr receiver) { return ::make_shared(_bind_index, make_in_receiver(receiver)); } diff --git a/cql3/abstract_marker.hh b/cql3/abstract_marker.hh index 2907d8e7d5..29ad6d1932 100644 --- a/cql3/abstract_marker.hh +++ b/cql3/abstract_marker.hh @@ -65,9 +65,9 @@ public: : _bind_index{bind_index} { } - virtual ::shared_ptr prepare(const sstring& keyspace, ::shared_ptr receiver) override; + virtual ::shared_ptr prepare(database& db, const sstring& keyspace, ::shared_ptr receiver) override; - virtual assignment_testable::test_result test_assignment(const sstring& keyspace, ::shared_ptr receiver) override { + virtual assignment_testable::test_result test_assignment(database& db, const sstring& keyspace, ::shared_ptr receiver) override { return assignment_testable::test_result::WEAKLY_ASSIGNABLE; } @@ -95,7 +95,7 @@ public: } public: - virtual ::shared_ptr prepare(const sstring& keyspace, ::shared_ptr receiver) override; + virtual ::shared_ptr prepare(database& db, const sstring& keyspace, ::shared_ptr receiver) override; }; }; diff --git a/cql3/assignment_testable.hh b/cql3/assignment_testable.hh index f6048eff6a..c4d6703e72 100644 --- a/cql3/assignment_testable.hh +++ b/cql3/assignment_testable.hh @@ -29,6 +29,8 @@ #include #include +class database; + namespace cql3 { class assignment_testable { @@ -52,7 +54,7 @@ public: // Test all elements of toTest for assignment. If all are exact match, return exact match. If any is not assignable, // return not assignable. Otherwise, return weakly assignable. template - static test_result test_all(const sstring& keyspace, ::shared_ptr receiver, + static test_result test_all(database& db, const sstring& keyspace, ::shared_ptr receiver, AssignmentTestablePtrRange&& to_test) { test_result res = test_result::EXACT_MATCH; for (auto&& rt : to_test) { @@ -61,7 +63,7 @@ public: continue; } - test_result t = rt->test_assignment(keyspace, receiver); + test_result t = rt->test_assignment(db, keyspace, receiver); if (t == test_result::NOT_ASSIGNABLE) { return test_result::NOT_ASSIGNABLE; } @@ -81,7 +83,7 @@ public: * Most caller should just call the isAssignable() method on the result, though functions have a use for * testing "strong" equality to decide the most precise overload to pick when multiple could match. */ - virtual test_result test_assignment(const sstring& keyspace, ::shared_ptr receiver) = 0; + virtual test_result test_assignment(database& db, const sstring& keyspace, ::shared_ptr receiver) = 0; // for error reporting virtual sstring assignment_testable_source_context() const = 0; diff --git a/cql3/attributes.hh b/cql3/attributes.hh index 290d94c3f8..fbcb3211ce 100644 --- a/cql3/attributes.hh +++ b/cql3/attributes.hh @@ -126,9 +126,9 @@ public: ::shared_ptr timestamp; ::shared_ptr time_to_live; - std::unique_ptr prepare(const sstring& ks_name, const sstring& cf_name) { - auto ts = !timestamp ? ::shared_ptr{} : timestamp->prepare(ks_name, timestamp_receiver(ks_name, cf_name)); - auto ttl = !time_to_live ? ::shared_ptr{} : time_to_live->prepare(ks_name, time_to_live_receiver(ks_name, cf_name)); + std::unique_ptr prepare(database& db, const sstring& ks_name, const sstring& cf_name) { + auto ts = !timestamp ? ::shared_ptr{} : timestamp->prepare(db, ks_name, timestamp_receiver(ks_name, cf_name)); + auto ttl = !time_to_live ? ::shared_ptr{} : time_to_live->prepare(db, ks_name, time_to_live_receiver(ks_name, cf_name)); return std::unique_ptr{new attributes{std::move(ts), std::move(ttl)}}; } diff --git a/cql3/column_condition.cc b/cql3/column_condition.cc index 839acdbb4d..681aa7497b 100644 --- a/cql3/column_condition.cc +++ b/cql3/column_condition.cc @@ -58,7 +58,7 @@ void column_condition::collect_marker_specificaton(::shared_ptr -column_condition::raw::prepare(const sstring& keyspace, const column_definition& receiver) { +column_condition::raw::prepare(database& db, const sstring& keyspace, const column_definition& receiver) { if (receiver.type->is_counter()) { throw exceptions::invalid_request_exception("Conditions on counters are not supported"); } @@ -66,16 +66,16 @@ column_condition::raw::prepare(const sstring& keyspace, const column_definition& if (!_collection_element) { if (_op == operator_type::IN) { if (_in_values.empty()) { // ? - return column_condition::in_condition(receiver, _in_marker->prepare(keyspace, receiver.column_specification)); + return column_condition::in_condition(receiver, _in_marker->prepare(db, keyspace, receiver.column_specification)); } std::vector<::shared_ptr> terms; for (auto&& value : _in_values) { - terms.push_back(value->prepare(keyspace, receiver.column_specification)); + terms.push_back(value->prepare(db, keyspace, receiver.column_specification)); } return column_condition::in_condition(receiver, std::move(terms)); } else { - return column_condition::condition(receiver, _value->prepare(keyspace, receiver.column_specification), _op); + return column_condition::condition(receiver, _value->prepare(db, keyspace, receiver.column_specification), _op); } } diff --git a/cql3/column_condition.hh b/cql3/column_condition.hh index 86e21b9b72..f697a479ac 100644 --- a/cql3/column_condition.hh +++ b/cql3/column_condition.hh @@ -758,7 +758,7 @@ public: std::move(collection_element), operator_type::IN); } - ::shared_ptr prepare(const sstring& keyspace, const column_definition& receiver); + ::shared_ptr prepare(database& db, const sstring& keyspace, const column_definition& receiver); }; }; diff --git a/cql3/column_identifier.cc b/cql3/column_identifier.cc index 36a64f4594..631cd9cdd7 100644 --- a/cql3/column_identifier.cc +++ b/cql3/column_identifier.cc @@ -31,7 +31,7 @@ std::ostream& operator<<(std::ostream& out, const column_identifier::raw& id) { } ::shared_ptr -column_identifier::new_selector_factory(schema_ptr schema, std::vector& defs) { +column_identifier::new_selector_factory(database& db, schema_ptr schema, std::vector& defs) { auto def = get_column_definition(schema, *this); if (!def) { throw exceptions::invalid_request_exception(sprint("Undefined name %s in selection clause", _text)); diff --git a/cql3/column_identifier.hh b/cql3/column_identifier.hh index 34afaf46c3..d663edc10b 100644 --- a/cql3/column_identifier.hh +++ b/cql3/column_identifier.hh @@ -127,7 +127,7 @@ public: } #endif - virtual ::shared_ptr new_selector_factory(schema_ptr schema, + virtual ::shared_ptr new_selector_factory(database& db, schema_ptr schema, std::vector& defs) override; /** diff --git a/cql3/constants.cc b/cql3/constants.cc index 191fe40911..60a1d7cdeb 100644 --- a/cql3/constants.cc +++ b/cql3/constants.cc @@ -63,7 +63,7 @@ constants::literal::parsed_value(::shared_ptr validator) } assignment_testable::test_result -constants::literal::test_assignment(const sstring& keyspace, ::shared_ptr receiver) +constants::literal::test_assignment(database& db, const sstring& keyspace, ::shared_ptr receiver) { auto receiver_type = receiver->type->as_cql3_type(); if (receiver_type->is_collection()) { @@ -127,9 +127,9 @@ constants::literal::test_assignment(const sstring& keyspace, ::shared_ptr -constants::literal::prepare(const sstring& keyspace, ::shared_ptr receiver) +constants::literal::prepare(database& db, const sstring& keyspace, ::shared_ptr receiver) { - if (!is_assignable(test_assignment(keyspace, receiver))) { + if (!is_assignable(test_assignment(db, keyspace, receiver))) { throw exceptions::invalid_request_exception(sprint("Invalid %s constant (%s) for \"%s\" of type %s", _type, _text, *receiver->name, receiver->type->as_cql3_type()->to_string())); } diff --git a/cql3/constants.hh b/cql3/constants.hh index a53fda5ddc..b6a502b2cf 100644 --- a/cql3/constants.hh +++ b/cql3/constants.hh @@ -90,14 +90,15 @@ public: }; static const ::shared_ptr NULL_VALUE; public: - virtual ::shared_ptr prepare(const sstring& keyspace, ::shared_ptr receiver) override { - if (!is_assignable(test_assignment(keyspace, receiver))) { + virtual ::shared_ptr prepare(database& db, const sstring& keyspace, ::shared_ptr receiver) override { + if (!is_assignable(test_assignment(db, keyspace, receiver))) { throw exceptions::invalid_request_exception("Invalid null value for counter increment/decrement"); } return NULL_VALUE; } - virtual assignment_testable::test_result test_assignment(const sstring& keyspace, + virtual assignment_testable::test_result test_assignment(database& db, + const sstring& keyspace, ::shared_ptr receiver) override { return receiver->type->is_counter() ? assignment_testable::test_result::NOT_ASSIGNABLE @@ -145,7 +146,7 @@ public: return ::make_shared(type::HEX, text); } - virtual ::shared_ptr prepare(const sstring& keyspace, ::shared_ptr receiver); + virtual ::shared_ptr prepare(database& db, const sstring& keyspace, ::shared_ptr receiver); private: bytes parsed_value(::shared_ptr validator); public: @@ -153,7 +154,7 @@ public: return _text; } - virtual assignment_testable::test_result test_assignment(const sstring& keyspace, ::shared_ptr receiver); + virtual assignment_testable::test_result test_assignment(database& db, const sstring& keyspace, ::shared_ptr receiver); virtual sstring to_string() const override { return _type == type::STRING ? sstring(sprint("'%s'", _text)) : _text; diff --git a/cql3/cql3_type.cc b/cql3/cql3_type.cc index ac6a015a7b..d1782f935a 100644 --- a/cql3/cql3_type.cc +++ b/cql3/cql3_type.cc @@ -14,7 +14,7 @@ public: : _type{type} { } public: - virtual shared_ptr prepare(const sstring& keyspace) { + virtual shared_ptr prepare(database& db, const sstring& keyspace) { return _type; } @@ -61,7 +61,7 @@ public: return true; } - virtual shared_ptr prepare(const sstring& keyspace) override { + virtual shared_ptr prepare(database& db, const sstring& keyspace) override { assert(_values); // "Got null values type for a collection"; if (!_frozen && _values->supports_freezing() && !_values->_frozen) { @@ -78,12 +78,12 @@ public: } if (_kind == &collection_type_impl::kind::list) { - return make_shared(cql3_type(to_string(), list_type_impl::get_instance(_values->prepare(keyspace)->get_type(), !_frozen), false)); + return make_shared(cql3_type(to_string(), list_type_impl::get_instance(_values->prepare(db, keyspace)->get_type(), !_frozen), false)); } else if (_kind == &collection_type_impl::kind::set) { - return make_shared(cql3_type(to_string(), set_type_impl::get_instance(_values->prepare(keyspace)->get_type(), !_frozen), false)); + return make_shared(cql3_type(to_string(), set_type_impl::get_instance(_values->prepare(db, keyspace)->get_type(), !_frozen), false)); } else if (_kind == &collection_type_impl::kind::map) { assert(_keys); // "Got null keys type for a collection"; - return make_shared(cql3_type(to_string(), map_type_impl::get_instance(_keys->prepare(keyspace)->get_type(), _values->prepare(keyspace)->get_type(), !_frozen), false)); + return make_shared(cql3_type(to_string(), map_type_impl::get_instance(_keys->prepare(db, keyspace)->get_type(), _values->prepare(db, keyspace)->get_type(), !_frozen), false)); } abort(); } @@ -122,7 +122,7 @@ public: } _frozen = true; } - virtual shared_ptr prepare(const sstring& keyspace) override { + virtual shared_ptr prepare(database& db, const sstring& keyspace) override { if (!_frozen) { freeze(); } @@ -131,7 +131,7 @@ public: if (t->is_counter()) { throw exceptions::invalid_request_exception("Counters are not allowed inside tuples"); } - ts.push_back(t->prepare(keyspace)->get_type()); + ts.push_back(t->prepare(db, keyspace)->get_type()); } return make_cql3_tuple_type(tuple_type_impl::get_instance(std::move(ts))); } diff --git a/cql3/cql3_type.hh b/cql3/cql3_type.hh index 9047c77774..14db8eba75 100644 --- a/cql3/cql3_type.hh +++ b/cql3/cql3_type.hh @@ -29,6 +29,8 @@ #include #include "enum_set.hh" +class database; + namespace cql3 { class cql3_type final { @@ -53,7 +55,7 @@ public: virtual bool is_counter() const; virtual std::experimental::optional keyspace() const; virtual void freeze(); - virtual shared_ptr prepare(const sstring& keyspace) = 0; + virtual shared_ptr prepare(database& db, const sstring& keyspace) = 0; static shared_ptr from(shared_ptr type); #if 0 public static Raw userType(UTName name) diff --git a/cql3/functions/function_call.hh b/cql3/functions/function_call.hh index 9519fa985f..195ca96a29 100644 --- a/cql3/functions/function_call.hh +++ b/cql3/functions/function_call.hh @@ -57,12 +57,12 @@ public: raw(function_name name, std::vector> terms) : _name(std::move(name)), _terms(std::move(terms)) { } - virtual ::shared_ptr prepare(const sstring& keyspace, ::shared_ptr receiver) override; + virtual ::shared_ptr prepare(database& db, const sstring& keyspace, ::shared_ptr receiver) override; private: // All parameters must be terminal static bytes_opt execute(scalar_function& fun, std::vector> parameters); public: - virtual assignment_testable::test_result test_assignment(const sstring& keyspace, shared_ptr receiver) override; + virtual assignment_testable::test_result test_assignment(database& db, const sstring& keyspace, shared_ptr receiver) override; virtual sstring to_string() const override; }; }; diff --git a/cql3/functions/functions.cc b/cql3/functions/functions.cc index eeb92061f8..a94950c715 100644 --- a/cql3/functions/functions.cc +++ b/cql3/functions/functions.cc @@ -92,7 +92,8 @@ functions::get_overload_count(const function_name& name) { } shared_ptr -functions::get(const sstring& keyspace, +functions::get(database& db, + const sstring& keyspace, const function_name& name, const std::vector>& provided_args, const sstring& receiver_ks, @@ -128,13 +129,13 @@ functions::get(const sstring& keyspace, // Fast path if there is only one choice if (candidates.size() == 1) { auto fun = std::move(candidates[0]); - validate_types(keyspace, fun, provided_args, receiver_ks, receiver_cf); + validate_types(db, keyspace, fun, provided_args, receiver_ks, receiver_cf); return fun; } std::vector> compatibles; for (auto&& to_test : candidates) { - auto r = match_arguments(keyspace, to_test, provided_args, receiver_ks, receiver_cf); + auto r = match_arguments(db, keyspace, to_test, provided_args, receiver_ks, receiver_cf); switch (r) { case assignment_testable::test_result::EXACT_MATCH: // We always favor exact matches @@ -186,7 +187,8 @@ functions::find(const function_name& name, const std::vector& arg_typ // This method and matchArguments are somewhat duplicate, but this method allows us to provide more precise errors in the common // case where there is no override for a given function. This is thus probably worth the minor code duplication. void -functions::validate_types(const sstring& keyspace, +functions::validate_types(database& db, + const sstring& keyspace, shared_ptr fun, const std::vector>& provided_args, const sstring& receiver_ks, @@ -207,7 +209,7 @@ functions::validate_types(const sstring& keyspace, } auto&& expected = make_arg_spec(receiver_ks, receiver_cf, *fun, i); - if (!is_assignable(provided->test_assignment(keyspace, expected))) { + if (!is_assignable(provided->test_assignment(db, keyspace, expected))) { throw exceptions::invalid_request_exception( sprint("Type error: %s cannot be passed as argument %d of function %s of type %s", provided, i, fun->name(), expected->type->as_cql3_type())); @@ -216,7 +218,7 @@ functions::validate_types(const sstring& keyspace, } assignment_testable::test_result -functions::match_arguments(const sstring& keyspace, +functions::match_arguments(database& db, const sstring& keyspace, shared_ptr fun, const std::vector>& provided_args, const sstring& receiver_ks, @@ -234,7 +236,7 @@ functions::match_arguments(const sstring& keyspace, continue; } auto&& expected = make_arg_spec(receiver_ks, receiver_cf, *fun, i); - auto arg_res = provided->test_assignment(keyspace, expected); + auto arg_res = provided->test_assignment(db, keyspace, expected); if (arg_res == assignment_testable::test_result::NOT_ASSIGNABLE) { return assignment_testable::test_result::NOT_ASSIGNABLE; } @@ -340,14 +342,14 @@ function_call::make_terminal(shared_ptr fun, bytes_opt result, seriali } ::shared_ptr -function_call::raw::prepare(const sstring& keyspace, ::shared_ptr receiver) { +function_call::raw::prepare(database& db, const sstring& keyspace, ::shared_ptr receiver) { std::vector> args; args.reserve(_terms.size()); std::transform(_terms.begin(), _terms.end(), std::back_inserter(args), [] (auto&& x) -> shared_ptr { return x; }); - auto&& fun = functions::functions::get(keyspace, _name, args, receiver->ks_name, receiver->cf_name); + auto&& fun = functions::functions::get(db, keyspace, _name, args, receiver->ks_name, receiver->cf_name); if (!fun) { throw exceptions::invalid_request_exception(sprint("Unknown function %s called", _name)); } @@ -375,7 +377,7 @@ function_call::raw::prepare(const sstring& keyspace, ::shared_ptrprepare(keyspace, functions::make_arg_spec(receiver->ks_name, receiver->cf_name, *scalar_fun, i)); + auto&& t = _terms[i]->prepare(db, keyspace, functions::make_arg_spec(receiver->ks_name, receiver->cf_name, *scalar_fun, i)); if (dynamic_cast(t.get())) { all_terminal = false; } @@ -405,13 +407,13 @@ function_call::raw::execute(scalar_function& fun, std::vector> } assignment_testable::test_result -function_call::raw::test_assignment(const sstring& keyspace, shared_ptr receiver) { +function_call::raw::test_assignment(database& db, const sstring& keyspace, shared_ptr receiver) { // Note: Functions.get() will return null if the function doesn't exist, or throw is no function matching // the arguments can be found. We may get one of those if an undefined/wrong function is used as argument // of another, existing, function. In that case, we return true here because we'll throw a proper exception // later with a more helpful error message that if we were to return false here. try { - auto&& fun = functions::get(keyspace, _name, _terms, receiver->ks_name, receiver->cf_name); + auto&& fun = functions::get(db, keyspace, _name, _terms, receiver->ks_name, receiver->cf_name); if (fun && receiver->type->equals(fun->return_type())) { return assignment_testable::test_result::EXACT_MATCH; } else if (!fun || receiver->type->is_value_compatible_with(*fun->return_type())) { diff --git a/cql3/functions/functions.hh b/cql3/functions/functions.hh index f427dce4bc..2b2fceb529 100644 --- a/cql3/functions/functions.hh +++ b/cql3/functions/functions.hh @@ -58,31 +58,34 @@ public: const function& fun, size_t i); static int get_overload_count(const function_name& name); public: - static shared_ptr get(const sstring& keyspace, + static shared_ptr get(database& db, + const sstring& keyspace, const function_name& name, const std::vector>& provided_args, const sstring& receiver_ks, const sstring& receiver_cf); template - static shared_ptr get(const sstring& keyspace, + static shared_ptr get(database& db, + const sstring& keyspace, const function_name& name, AssignmentTestablePtrRange&& provided_args, const sstring& receiver_ks, const sstring& receiver_cf) { const std::vector> args(std::begin(provided_args), std::end(provided_args)); - return get(keyspace, name, args, receiver_ks, receiver_cf); + return get(db, keyspace, name, args, receiver_ks, receiver_cf); } static std::vector> find(const function_name& name); static shared_ptr find(const function_name& name, const std::vector& arg_types); private: // This method and matchArguments are somewhat duplicate, but this method allows us to provide more precise errors in the common // case where there is no override for a given function. This is thus probably worth the minor code duplication. - static void validate_types(const sstring& keyspace, + static void validate_types(database& db, + const sstring& keyspace, shared_ptr fun, const std::vector>& provided_args, const sstring& receiver_ks, const sstring& receiver_cf); - static assignment_testable::test_result match_arguments(const sstring& keyspace, + static assignment_testable::test_result match_arguments(database& db, const sstring& keyspace, shared_ptr fun, const std::vector>& provided_args, const sstring& receiver_ks, diff --git a/cql3/lists.cc b/cql3/lists.cc index 606550a814..cc138126ad 100644 --- a/cql3/lists.cc +++ b/cql3/lists.cc @@ -25,15 +25,15 @@ lists::value_spec_of(shared_ptr column) { } shared_ptr -lists::literal::prepare(const sstring& keyspace, shared_ptr receiver) { - validate_assignable_to(keyspace, receiver); +lists::literal::prepare(database& db, const sstring& keyspace, shared_ptr receiver) { + validate_assignable_to(db, keyspace, receiver); auto&& value_spec = value_spec_of(receiver); std::vector> values; values.reserve(_elements.size()); bool all_terminal = true; for (auto rt : _elements) { - auto&& t = rt->prepare(keyspace, value_spec); + auto&& t = rt->prepare(db, keyspace, value_spec); if (t->contains_bind_marker()) { throw exceptions::invalid_request_exception(sprint("Invalid list literal for %s: bind variables are not supported inside collection literals", *receiver->name)); @@ -52,14 +52,14 @@ lists::literal::prepare(const sstring& keyspace, shared_ptr receiver) { +lists::literal::validate_assignable_to(database& db, const sstring keyspace, shared_ptr receiver) { if (!dynamic_pointer_cast(receiver->type)) { throw exceptions::invalid_request_exception(sprint("Invalid list literal for %s of type %s", *receiver->name, *receiver->type->as_cql3_type())); } auto&& value_spec = value_spec_of(receiver); for (auto rt : _elements) { - if (!is_assignable(rt->test_assignment(keyspace, value_spec))) { + if (!is_assignable(rt->test_assignment(db, keyspace, value_spec))) { throw exceptions::invalid_request_exception(sprint("Invalid list literal for %s: value %s is not of type %s", *receiver->name, *rt, *value_spec->type->as_cql3_type())); } @@ -67,7 +67,7 @@ lists::literal::validate_assignable_to(const sstring keyspace, shared_ptr receiver) { +lists::literal::test_assignment(database& db, const sstring& keyspace, shared_ptr receiver) { if (!dynamic_pointer_cast(receiver->type)) { return assignment_testable::test_result::NOT_ASSIGNABLE; } @@ -81,7 +81,7 @@ lists::literal::test_assignment(const sstring& keyspace, shared_ptr> to_test; to_test.reserve(_elements.size()); std::copy(_elements.begin(), _elements.end(), std::back_inserter(to_test)); - return assignment_testable::test_all(keyspace, value_spec, to_test); + return assignment_testable::test_all(db, keyspace, value_spec, to_test); } sstring diff --git a/cql3/lists.hh b/cql3/lists.hh index e7fd3f6164..43787c6cc0 100644 --- a/cql3/lists.hh +++ b/cql3/lists.hh @@ -73,11 +73,11 @@ public: explicit literal(std::vector> elements) : _elements(std::move(elements)) { } - shared_ptr prepare(const sstring& keyspace, shared_ptr receiver); + shared_ptr prepare(database& db, const sstring& keyspace, shared_ptr receiver); private: - void validate_assignable_to(const sstring keyspace, shared_ptr receiver); + void validate_assignable_to(database& db, const sstring keyspace, shared_ptr receiver); public: - virtual assignment_testable::test_result test_assignment(const sstring& keyspace, shared_ptr receiver) override; + virtual assignment_testable::test_result test_assignment(database& db, const sstring& keyspace, shared_ptr receiver) override; virtual sstring to_string() const override; }; diff --git a/cql3/maps.cc b/cql3/maps.cc index 0d44a70033..e79244ff74 100644 --- a/cql3/maps.cc +++ b/cql3/maps.cc @@ -48,8 +48,8 @@ maps::value_spec_of(column_specification& column) { } ::shared_ptr -maps::literal::prepare(const sstring& keyspace, ::shared_ptr receiver) { - validate_assignable_to(keyspace, *receiver); +maps::literal::prepare(database& db, const sstring& keyspace, ::shared_ptr receiver) { + validate_assignable_to(db, keyspace, *receiver); auto key_spec = maps::key_spec_of(*receiver); auto value_spec = maps::value_spec_of(*receiver); @@ -57,8 +57,8 @@ maps::literal::prepare(const sstring& keyspace, ::shared_ptrprepare(keyspace, key_spec); - auto v = entry.second->prepare(keyspace, value_spec); + auto k = entry.first->prepare(db, keyspace, key_spec); + auto v = entry.second->prepare(db, keyspace, value_spec); if (k->contains_bind_marker() || v->contains_bind_marker()) { throw exceptions::invalid_request_exception(sprint("Invalid map literal for %s: bind variables are not supported inside collection literals", *receiver->name)); @@ -79,24 +79,24 @@ maps::literal::prepare(const sstring& keyspace, ::shared_ptr(receiver.type)) { throw exceptions::invalid_request_exception(sprint("Invalid map literal for %s of type %s", *receiver.name, *receiver.type->as_cql3_type())); } auto&& key_spec = maps::key_spec_of(receiver); auto&& value_spec = maps::value_spec_of(receiver); for (auto&& entry : entries) { - if (!is_assignable(entry.first->test_assignment(keyspace, key_spec))) { + if (!is_assignable(entry.first->test_assignment(db, keyspace, key_spec))) { throw exceptions::invalid_request_exception(sprint("Invalid map literal for %s: key %s is not of type %s", *receiver.name, *entry.first, *key_spec->type->as_cql3_type())); } - if (!is_assignable(entry.second->test_assignment(keyspace, value_spec))) { + if (!is_assignable(entry.second->test_assignment(db, keyspace, value_spec))) { throw exceptions::invalid_request_exception(sprint("Invalid map literal for %s: value %s is not of type %s", *receiver.name, *entry.second, *value_spec->type->as_cql3_type())); } } } assignment_testable::test_result -maps::literal::test_assignment(const sstring& keyspace, ::shared_ptr receiver) { +maps::literal::test_assignment(database& db, const sstring& keyspace, ::shared_ptr receiver) { throw std::runtime_error("not implemented"); #if 0 if (!(receiver.type instanceof MapType)) diff --git a/cql3/maps.hh b/cql3/maps.hh index 18701fdcba..9ad8a99f34 100644 --- a/cql3/maps.hh +++ b/cql3/maps.hh @@ -50,11 +50,11 @@ public: literal(const std::vector, ::shared_ptr>>& entries_) : entries{entries_} { } - virtual ::shared_ptr prepare(const sstring& keyspace, ::shared_ptr receiver) override; + virtual ::shared_ptr prepare(database& db, const sstring& keyspace, ::shared_ptr receiver) override; private: - void validate_assignable_to(const sstring& keyspace, column_specification& receiver); + void validate_assignable_to(database& db, const sstring& keyspace, column_specification& receiver); public: - virtual assignment_testable::test_result test_assignment(const sstring& keyspace, ::shared_ptr receiver) override; + virtual assignment_testable::test_result test_assignment(database& db, const sstring& keyspace, ::shared_ptr receiver) override; virtual sstring to_string() const override; }; diff --git a/cql3/operation.cc b/cql3/operation.cc index 8251a7c7a9..3543a7706f 100644 --- a/cql3/operation.cc +++ b/cql3/operation.cc @@ -30,7 +30,7 @@ namespace cql3 { shared_ptr -operation::set_element::prepare(const sstring& keyspace, const column_definition& receiver) { +operation::set_element::prepare(database& db, const sstring& keyspace, const column_definition& receiver) { using exceptions::invalid_request_exception; auto rtype = dynamic_pointer_cast(receiver.type); if (!rtype) { @@ -40,14 +40,14 @@ operation::set_element::prepare(const sstring& keyspace, const column_definition } if (&rtype->_kind == &collection_type_impl::kind::list) { - auto&& idx = _selector->prepare(keyspace, lists::index_spec_of(receiver.column_specification)); - auto&& lval = _value->prepare(keyspace, lists::value_spec_of(receiver.column_specification)); + auto&& idx = _selector->prepare(db, keyspace, lists::index_spec_of(receiver.column_specification)); + auto&& lval = _value->prepare(db, keyspace, lists::value_spec_of(receiver.column_specification)); return make_shared(receiver, idx, lval); } else if (&rtype->_kind == &collection_type_impl::kind::set) { throw invalid_request_exception(sprint("Invalid operation (%s) for set column %s", receiver, receiver.name())); } else if (&rtype->_kind == &collection_type_impl::kind::map) { - auto key = _selector->prepare(keyspace, maps::key_spec_of(*receiver.column_specification)); - auto mval = _value->prepare(keyspace, maps::value_spec_of(*receiver.column_specification)); + auto key = _selector->prepare(db, keyspace, maps::key_spec_of(*receiver.column_specification)); + auto mval = _value->prepare(db, keyspace, maps::value_spec_of(*receiver.column_specification)); return make_shared(receiver, key, mval); } abort(); @@ -61,8 +61,8 @@ operation::set_element::is_compatible_with(shared_ptr other) { } shared_ptr -operation::addition::prepare(const sstring& keyspace, const column_definition& receiver) { - auto v = _value->prepare(keyspace, receiver.column_specification); +operation::addition::prepare(database& db, const sstring& keyspace, const column_definition& receiver) { + auto v = _value->prepare(db, keyspace, receiver.column_specification); auto ctype = dynamic_pointer_cast(receiver.type); if (!ctype) { @@ -97,7 +97,7 @@ operation::addition::is_compatible_with(shared_ptr other) { } shared_ptr -operation::subtraction::prepare(const sstring& keyspace, const column_definition& receiver) { +operation::subtraction::prepare(database& db, const sstring& keyspace, const column_definition& receiver) { auto ctype = dynamic_pointer_cast(receiver.type); if (!ctype) { fail(unimplemented::cause::COUNTERS); @@ -113,9 +113,9 @@ operation::subtraction::prepare(const sstring& keyspace, const column_definition } if (&ctype->_kind == &collection_type_impl::kind::list) { - return make_shared(receiver, _value->prepare(keyspace, receiver.column_specification)); + return make_shared(receiver, _value->prepare(db, keyspace, receiver.column_specification)); } else if (&ctype->_kind == &collection_type_impl::kind::set) { - return make_shared(receiver, _value->prepare(keyspace, receiver.column_specification)); + return make_shared(receiver, _value->prepare(db, keyspace, receiver.column_specification)); } else if (&ctype->_kind == &collection_type_impl::kind::map) { auto&& mtype = dynamic_pointer_cast(ctype); // The value for a map subtraction is actually a set @@ -124,7 +124,7 @@ operation::subtraction::prepare(const sstring& keyspace, const column_definition receiver.column_specification->cf_name, receiver.column_specification->name, set_type_impl::get_instance(mtype->get_keys_type(), false)); - return make_shared(receiver, _value->prepare(keyspace, std::move(vr))); + return make_shared(receiver, _value->prepare(db, keyspace, std::move(vr))); } abort(); } @@ -135,7 +135,7 @@ operation::subtraction::is_compatible_with(shared_ptr other) { } shared_ptr -operation::prepend::prepare(const sstring& keyspace, const column_definition& receiver) { +operation::prepend::prepare(database& db, const sstring& keyspace, const column_definition& receiver) { warn(unimplemented::cause::COLLECTIONS); throw exceptions::invalid_request_exception("unimplemented, go away"); // FIXME: @@ -158,8 +158,8 @@ operation::prepend::is_compatible_with(shared_ptr other) { ::shared_ptr -operation::set_value::prepare(const sstring& keyspace, const column_definition& receiver) { - auto v = _value->prepare(keyspace, receiver.column_specification); +operation::set_value::prepare(database& db, const sstring& keyspace, const column_definition& receiver) { + auto v = _value->prepare(db, keyspace, receiver.column_specification); if (receiver.type->is_counter()) { throw exceptions::invalid_request_exception(sprint("Cannot set the value of counter column %s (counters can only be incremented/decremented, not set)", receiver.name_as_text())); @@ -194,7 +194,7 @@ operation::element_deletion::affected_column() { } shared_ptr -operation::element_deletion::prepare(const sstring& keyspace, const column_definition& receiver) { +operation::element_deletion::prepare(database& db, const sstring& keyspace, const column_definition& receiver) { if (!receiver.type->is_collection()) { throw exceptions::invalid_request_exception(sprint("Invalid deletion operation for non collection column %s", receiver.name())); } else if (!receiver.type->is_multi_cell()) { @@ -202,13 +202,13 @@ operation::element_deletion::prepare(const sstring& keyspace, const column_defin } auto ctype = static_pointer_cast(receiver.type); if (&ctype->_kind == &collection_type_impl::kind::list) { - auto&& idx = _element->prepare(keyspace, lists::index_spec_of(receiver.column_specification)); + auto&& idx = _element->prepare(db, keyspace, lists::index_spec_of(receiver.column_specification)); return make_shared(receiver, std::move(idx)); } else if (&ctype->_kind == &collection_type_impl::kind::set) { - auto&& elt = _element->prepare(keyspace, sets::value_spec_of(receiver.column_specification)); + auto&& elt = _element->prepare(db, keyspace, sets::value_spec_of(receiver.column_specification)); return make_shared(receiver, std::move(elt)); } else if (&ctype->_kind == &collection_type_impl::kind::map) { - auto&& key = _element->prepare(keyspace, maps::key_spec_of(*receiver.column_specification)); + auto&& key = _element->prepare(db, keyspace, maps::key_spec_of(*receiver.column_specification)); return make_shared(receiver, std::move(key)); } abort(); diff --git a/cql3/operation.hh b/cql3/operation.hh index 81a1c4c3a7..bb0c859112 100644 --- a/cql3/operation.hh +++ b/cql3/operation.hh @@ -136,7 +136,7 @@ public: * be a true column. * @return the prepared update operation. */ - virtual ::shared_ptr prepare(const sstring& keyspace, const column_definition& receiver) = 0; + virtual ::shared_ptr prepare(database& db, const sstring& keyspace, const column_definition& receiver) = 0; /** * @return whether this operation can be applied alongside the {@code @@ -172,7 +172,7 @@ public: * @param receiver the "column" this operation applies to. * @return the prepared delete operation. */ - virtual ::shared_ptr prepare(const sstring& keyspace, const column_definition& receiver) = 0; + virtual ::shared_ptr prepare(database& db, const sstring& keyspace, const column_definition& receiver) = 0; }; class set_value; @@ -185,7 +185,7 @@ public: : _selector(std::move(selector)), _value(std::move(value)) { } - virtual shared_ptr prepare(const sstring& keyspace, const column_definition& receiver); + virtual shared_ptr prepare(database& db, const sstring& keyspace, const column_definition& receiver); #if 0 protected String toString(ColumnSpecification column) { @@ -203,7 +203,7 @@ public: : _value(value) { } - virtual shared_ptr prepare(const sstring& keyspace, const column_definition& receiver) override; + virtual shared_ptr prepare(database& db, const sstring& keyspace, const column_definition& receiver) override; #if 0 protected String toString(ColumnSpecification column) @@ -222,7 +222,7 @@ public: : _value(value) { } - virtual shared_ptr prepare(const sstring& keyspace, const column_definition& receiver) override; + virtual shared_ptr prepare(database& db, const sstring& keyspace, const column_definition& receiver) override; #if 0 protected String toString(ColumnSpecification column) @@ -241,7 +241,7 @@ public: : _value(std::move(value)) { } - virtual shared_ptr prepare(const sstring& keyspace, const column_definition& receiver) override; + virtual shared_ptr prepare(database& db, const sstring& keyspace, const column_definition& receiver) override; #if 0 protected String toString(ColumnSpecification column) @@ -262,7 +262,7 @@ public: : _id(std::move(id)), _element(std::move(element)) { } virtual shared_ptr affected_column() override; - virtual shared_ptr prepare(const sstring& keyspace, const column_definition& receiver) override; + virtual shared_ptr prepare(database& db, const sstring& keyspace, const column_definition& receiver) override; }; }; diff --git a/cql3/operation_impl.hh b/cql3/operation_impl.hh index 94c39ce581..0bade532c0 100644 --- a/cql3/operation_impl.hh +++ b/cql3/operation_impl.hh @@ -38,7 +38,7 @@ private: public: set_value(::shared_ptr value) : _value(std::move(value)) {} - virtual ::shared_ptr prepare(const sstring& keyspace, const column_definition& receiver) override; + virtual ::shared_ptr prepare(database& db, const sstring& keyspace, const column_definition& receiver) override; #if 0 protected String toString(ColumnSpecification column) @@ -62,7 +62,7 @@ public: return _id; } - ::shared_ptr prepare(const sstring& keyspace, const column_definition& receiver) { + ::shared_ptr prepare(database& db, const sstring& keyspace, const column_definition& receiver) { // No validation, deleting a column is always "well typed" return ::make_shared(receiver); } diff --git a/cql3/relation.hh b/cql3/relation.hh index f12d62b8c6..190e99cb18 100644 --- a/cql3/relation.hh +++ b/cql3/relation.hh @@ -122,23 +122,23 @@ public: * @return the Restriction corresponding to this Relation * @throws InvalidRequestException if this Relation is not valid */ - virtual ::shared_ptr to_restriction(schema_ptr schema, ::shared_ptr bound_names) final { + virtual ::shared_ptr to_restriction(database& db, schema_ptr schema, ::shared_ptr bound_names) final { if (_relation_type == operator_type::EQ) { - return new_EQ_restriction(schema, bound_names); + return new_EQ_restriction(db, schema, bound_names); } else if (_relation_type == operator_type::LT) { - return new_slice_restriction(schema, bound_names, statements::bound::END, false); + return new_slice_restriction(db, schema, bound_names, statements::bound::END, false); } else if (_relation_type == operator_type::LTE) { - return new_slice_restriction(schema, bound_names, statements::bound::END, true); + return new_slice_restriction(db, schema, bound_names, statements::bound::END, true); } else if (_relation_type == operator_type::GTE) { - return new_slice_restriction(schema, bound_names, statements::bound::START, true); + return new_slice_restriction(db, schema, bound_names, statements::bound::START, true); } else if (_relation_type == operator_type::GT) { - return new_slice_restriction(schema, bound_names, statements::bound::START, false); + return new_slice_restriction(db, schema, bound_names, statements::bound::START, false); } else if (_relation_type == operator_type::IN) { - return new_IN_restriction(schema, bound_names); + return new_IN_restriction(db, schema, bound_names); } else if (_relation_type == operator_type::CONTAINS) { - return new_contains_restriction(schema, bound_names, false); + return new_contains_restriction(db, schema, bound_names, false); } else if (_relation_type == operator_type::CONTAINS_KEY) { - return new_contains_restriction(schema, bound_names, true); + return new_contains_restriction(db, schema, bound_names, true); } else { throw exceptions::invalid_request_exception(sprint("Unsupported \"!=\" relation: %s", to_string())); } @@ -158,7 +158,7 @@ public: * @return a new EQ restriction instance. * @throws InvalidRequestException if the relation cannot be converted into an EQ restriction. */ - virtual ::shared_ptr new_EQ_restriction(schema_ptr schema, + virtual ::shared_ptr new_EQ_restriction(database& db, schema_ptr schema, ::shared_ptr bound_names) = 0; /** @@ -169,7 +169,7 @@ public: * @return a new IN restriction instance * @throws InvalidRequestException if the relation cannot be converted into an IN restriction. */ - virtual ::shared_ptr new_IN_restriction(schema_ptr schema, + virtual ::shared_ptr new_IN_restriction(database& db, schema_ptr schema, ::shared_ptr bound_names) = 0; /** @@ -182,7 +182,7 @@ public: * @return a new slice restriction instance * @throws InvalidRequestException if the Relation is not valid */ - virtual ::shared_ptr new_slice_restriction(schema_ptr schema, + virtual ::shared_ptr new_slice_restriction(database& db, schema_ptr schema, ::shared_ptr bound_names, statements::bound bound, bool inclusive) = 0; @@ -196,7 +196,7 @@ public: * @return a new Contains ::shared_ptr instance * @throws InvalidRequestException if the Relation is not valid */ - virtual ::shared_ptr new_contains_restriction(schema_ptr schema, + virtual ::shared_ptr new_contains_restriction(database& db, schema_ptr schema, ::shared_ptr bound_names, bool isKey) = 0; protected: @@ -213,6 +213,7 @@ protected: */ virtual ::shared_ptr to_term(const std::vector<::shared_ptr>& receivers, ::shared_ptr raw, + database& db, const sstring& keyspace, ::shared_ptr boundNames) = 0; @@ -228,11 +229,12 @@ protected: */ std::vector<::shared_ptr> to_terms(const std::vector<::shared_ptr>& receivers, const std::vector<::shared_ptr>& raws, + database& db, const sstring& keyspace, ::shared_ptr boundNames) { std::vector<::shared_ptr> terms; for (auto&& r : raws) { - terms.emplace_back(to_term(receivers, r, keyspace, boundNames)); + terms.emplace_back(to_term(receivers, r, db, keyspace, boundNames)); } return terms; } diff --git a/cql3/restrictions/statement_restrictions.hh b/cql3/restrictions/statement_restrictions.hh index b6d3ea37d3..eafa9697f7 100644 --- a/cql3/restrictions/statement_restrictions.hh +++ b/cql3/restrictions/statement_restrictions.hh @@ -89,7 +89,8 @@ public: , _nonprimary_key_restrictions(::make_shared(schema)) { } - statement_restrictions(schema_ptr schema, + statement_restrictions(database& db, + schema_ptr schema, const std::vector<::shared_ptr>& where_clause, ::shared_ptr bound_names, bool selects_only_static_columns, @@ -105,7 +106,7 @@ public: */ if (!where_clause.empty()) { for (auto&& relation : where_clause) { - add_restriction(relation->to_restriction(schema, bound_names)); + add_restriction(relation->to_restriction(db, schema, bound_names)); } } diff --git a/cql3/selection/FieldSelector.java b/cql3/selection/FieldSelector.java deleted file mode 100644 index 76dbb22f58..0000000000 --- a/cql3/selection/FieldSelector.java +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.cassandra.cql3.selection; - -import java.nio.ByteBuffer; - -import org.apache.cassandra.cql3.selection.Selection.ResultSetBuilder; -import org.apache.cassandra.db.marshal.AbstractType; -import org.apache.cassandra.db.marshal.UTF8Type; -import org.apache.cassandra.db.marshal.UserType; -import org.apache.cassandra.exceptions.InvalidRequestException; - -final class FieldSelector extends Selector -{ - private final UserType type; - private final int field; - private final Selector selected; - - public static Factory newFactory(final UserType type, final int field, final Selector.Factory factory) - { - return new Factory() - { - protected String getColumnName() - { - return String.format("%s.%s", - factory.getColumnName(), - UTF8Type.instance.getString(type.fieldName(field))); - } - - protected AbstractType getReturnType() - { - return type.fieldType(field); - } - - public Selector newInstance() throws InvalidRequestException - { - return new FieldSelector(type, field, factory.newInstance()); - } - - public boolean isAggregateSelectorFactory() - { - return factory.isAggregateSelectorFactory(); - } - }; - } - - public boolean isAggregate() - { - return false; - } - - public void addInput(int protocolVersion, ResultSetBuilder rs) throws InvalidRequestException - { - selected.addInput(protocolVersion, rs); - } - - public ByteBuffer getOutput(int protocolVersion) throws InvalidRequestException - { - ByteBuffer value = selected.getOutput(protocolVersion); - if (value == null) - return null; - ByteBuffer[] buffers = type.split(value); - return field < buffers.length ? buffers[field] : null; - } - - public AbstractType getType() - { - return type.fieldType(field); - } - - public void reset() - { - selected.reset(); - } - - @Override - public String toString() - { - return String.format("%s.%s", selected, UTF8Type.instance.getString(type.fieldName(field))); - } - - private FieldSelector(UserType type, int field, Selector selected) - { - this.type = type; - this.field = field; - this.selected = selected; - } -} \ No newline at end of file diff --git a/cql3/selection/field_selector.hh b/cql3/selection/field_selector.hh new file mode 100644 index 0000000000..137e84e5af --- /dev/null +++ b/cql3/selection/field_selector.hh @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* + * Copyright 2015 Cloudius Systems + * + * Modified by Cloudius Systems + */ + +#pragma once + +#include "selector.hh" +#include "types.hh" + +namespace cql3 { + +namespace selection { + +class field_selector : public selector { + user_type _type; + size_t _field; + shared_ptr _selected; +public: + static shared_ptr new_factory(user_type type, size_t field, shared_ptr factory) { + struct field_selector_factory : selector::factory { + user_type _type; + size_t _field; + shared_ptr _factory; + + field_selector_factory(user_type type, size_t field, shared_ptr factory) + : _type(std::move(type)), _field(field), _factory(std::move(factory)) { + } + + virtual sstring column_name() override { + auto&& name = _type->field_name(_field); + auto sname = sstring(reinterpret_cast(name.begin()), name.size()); + return sprint("%s.%s", _factory->column_name(), sname); + } + + virtual data_type get_return_type() override { + return _type->field_type(_field); + } + + shared_ptr new_instance() override { + return make_shared(_type, _field, _factory->new_instance()); + } + + bool is_aggregate_selector_factory() override { + return _factory->is_aggregate_selector_factory(); + } + }; + return make_shared(std::move(type), field, std::move(factory)); + } + + virtual bool is_aggregate() override { + return false; + } + + virtual void add_input(serialization_format sf, result_set_builder& rs) override { + _selected->add_input(sf, rs); + } + + virtual bytes_opt get_output(serialization_format sf) override { + auto&& value = _selected->get_output(sf); + if (!value) { + return std::experimental::nullopt; + } + auto&& buffers = _type->split(*value); + bytes_opt ret; + if (_field < buffers.size() && buffers[_field]) { + ret = to_bytes(*buffers[_field]); + } + return ret; + } + + virtual data_type get_type() override { + return _type->field_type(_field); + } + + virtual void reset() { + _selected->reset(); + } + + virtual sstring assignment_testable_source_context() const override { + auto&& name = _type->field_name(_field); + auto sname = sstring(reinterpret_cast(name.begin(), name.size())); + return sprint("%s.%s", _selected, sname); + } + + field_selector(user_type type, size_t field, shared_ptr selected) + : _type(std::move(type)), _field(field), _selected(std::move(selected)) { + } +}; + +} +} diff --git a/cql3/selection/selectable.cc b/cql3/selection/selectable.cc index db5011aab8..98a0e002b8 100644 --- a/cql3/selection/selectable.cc +++ b/cql3/selection/selectable.cc @@ -3,6 +3,8 @@ */ #include "selectable.hh" +#include "selectable_with_field_selection.hh" +#include "field_selector.hh" #include "writetime_or_ttl.hh" #include "selector_factories.hh" #include "cql3/functions/functions.hh" @@ -14,7 +16,7 @@ namespace cql3 { namespace selection { shared_ptr -selectable::writetime_or_ttl::new_selector_factory(schema_ptr s, std::vector& defs) { +selectable::writetime_or_ttl::new_selector_factory(database& db, schema_ptr s, std::vector& defs) { auto&& def = s->get_column_definition(_id->name()); if (!def) { throw exceptions::invalid_request_exception(sprint("Undefined name %s in selection clause", _id)); @@ -44,11 +46,11 @@ selectable::writetime_or_ttl::raw::processes_selection() const { } shared_ptr -selectable::with_function::new_selector_factory(schema_ptr s, std::vector& defs) { - auto&& factories = selector_factories::create_factories_and_collect_column_definitions(_args, s, defs); +selectable::with_function::new_selector_factory(database& db, schema_ptr s, std::vector& defs) { + auto&& factories = selector_factories::create_factories_and_collect_column_definitions(_args, db, s, defs); // resolve built-in functions before user defined functions - auto&& fun = functions::functions::get(s->ks_name, _function_name, factories->new_instances(), s->ks_name, s->cf_name); + auto&& fun = functions::functions::get(db, s->ks_name, _function_name, factories->new_instances(), s->ks_name, s->cf_name); if (!fun) { throw exceptions::invalid_request_exception(sprint("Unknown function '%s'", _function_name)); } @@ -74,6 +76,39 @@ selectable::with_function::raw::processes_selection() const { return true; } +shared_ptr +selectable::with_field_selection::new_selector_factory(database& db, schema_ptr s, std::vector& defs) { + auto&& factory = _selected->new_selector_factory(db, s, defs); + auto&& type = factory->new_instance()->get_type(); + auto&& ut = dynamic_pointer_cast(std::move(type)); + if (!ut) { + throw exceptions::invalid_request_exception( + sprint("Invalid field selection: %s of type %s is not a user type", + "FIXME: selectable" /* FIMXME: _selected */, ut->as_cql3_type())); + } + for (size_t i = 0; i < ut->size(); ++i) { + if (ut->field_name(i) != _field->bytes_) { + continue; + } + return field_selector::new_factory(std::move(ut), i, std::move(factory)); + } + throw exceptions::invalid_request_exception(sprint("%s of type %s has no field %s", + "FIXME: selectable" /* FIXME: _selected */, ut->as_cql3_type(), _field)); +} + +shared_ptr +selectable::with_field_selection::raw::prepare(schema_ptr s) { + // static_pointer_cast<> needed due to lack of covariant return type + // support with smart pointers + return make_shared(_selected->prepare(s), + static_pointer_cast(_field->prepare(s))); +} + +bool +selectable::with_field_selection::raw::processes_selection() const { + return true; +} + } } diff --git a/cql3/selection/selectable.hh b/cql3/selection/selectable.hh index 465e85bea3..8debaf504f 100644 --- a/cql3/selection/selectable.hh +++ b/cql3/selection/selectable.hh @@ -54,7 +54,7 @@ import org.apache.commons.lang3.text.StrBuilder; class selectable { public: virtual ~selectable() {} - virtual ::shared_ptr new_selector_factory(schema_ptr schema, std::vector& defs) = 0; + virtual ::shared_ptr new_selector_factory(database& db, schema_ptr schema, std::vector& defs) = 0; protected: static size_t add_and_get_index(const column_definition& def, std::vector& defs) { auto i = std::find(defs.begin(), defs.end(), &def); @@ -81,71 +81,7 @@ public: class with_function; -#if 0 - public static class WithFieldSelection extends Selectable - { - public final Selectable selected; - public final ColumnIdentifier field; - - public WithFieldSelection(Selectable selected, ColumnIdentifier field) - { - this.selected = selected; - this.field = field; - } - - @Override - public String toString() - { - return String.format("%s.%s", selected, field); - } - - public Selector.Factory newSelectorFactory(CFMetaData cfm, - List defs) throws InvalidRequestException - { - Selector.Factory factory = selected.newSelectorFactory(cfm, defs); - AbstractType type = factory.newInstance().getType(); - if (!(type instanceof UserType)) - throw new InvalidRequestException( - String.format("Invalid field selection: %s of type %s is not a user type", - selected, - type.asCQL3Type())); - - UserType ut = (UserType) type; - for (int i = 0; i < ut.size(); i++) - { - if (!ut.fieldName(i).equals(field.bytes)) - continue; - return FieldSelector.newFactory(ut, i, factory); - } - throw new InvalidRequestException(String.format("%s of type %s has no field %s", - selected, - type.asCQL3Type(), - field)); - } - - public static class Raw implements Selectable.Raw - { - private final Selectable.Raw selected; - private final ColumnIdentifier.Raw field; - - public Raw(Selectable.Raw selected, ColumnIdentifier.Raw field) - { - this.selected = selected; - this.field = field; - } - - public WithFieldSelection prepare(CFMetaData cfm) - { - return new WithFieldSelection(selected.prepare(cfm), field.prepare(cfm)); - } - - public boolean processesSelection() - { - return true; - } - } - } -#endif + class with_field_selection; }; class selectable::with_function : public selectable { @@ -168,7 +104,7 @@ public: } #endif - virtual shared_ptr new_selector_factory(schema_ptr s, std::vector& defs) override; + virtual shared_ptr new_selector_factory(database& db, schema_ptr s, std::vector& defs) override; class raw : public selectable::raw { functions::function_name _function_name; std::vector> _args; diff --git a/cql3/selection/selectable_with_field_selection.hh b/cql3/selection/selectable_with_field_selection.hh new file mode 100644 index 0000000000..f3c5502490 --- /dev/null +++ b/cql3/selection/selectable_with_field_selection.hh @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * Copyright 2015 Cloudius Systems + * + * Modified by Cloudius Systems + */ + + +#pragma once + +#include "selectable.hh" +#include "cql3/column_identifier.hh" + +namespace cql3 { + +namespace selection { + +class selectable::with_field_selection : public selectable { +public: + shared_ptr _selected; + shared_ptr _field; +public: + with_field_selection(shared_ptr selected, shared_ptr field) + : _selected(std::move(selected)), _field(std::move(field)) { + } + +#if 0 + @Override + public String toString() + { + return String.format("%s.%s", selected, field); + } +#endif + + virtual shared_ptr new_selector_factory(database& db, schema_ptr s, std::vector& defs) override; + + class raw : public selectable::raw { + shared_ptr _selected; + shared_ptr _field; + public: + raw(shared_ptr selected, shared_ptr field) + : _selected(std::move(selected)), _field(std::move(field)) { + } + virtual shared_ptr prepare(schema_ptr s) override; + virtual bool processes_selection() const override; + }; +}; + +} + +} diff --git a/cql3/selection/selection.cc b/cql3/selection/selection.cc index 62eb503be4..d8ac331c9f 100644 --- a/cql3/selection/selection.cc +++ b/cql3/selection/selection.cc @@ -198,12 +198,12 @@ uint32_t selection::add_column_for_ordering(const column_definition& c) { return _columns.size() - 1; } -::shared_ptr selection::from_selectors(schema_ptr schema, const std::vector<::shared_ptr>& raw_selectors) { +::shared_ptr selection::from_selectors(database& db, schema_ptr schema, const std::vector<::shared_ptr>& raw_selectors) { std::vector defs; ::shared_ptr factories = selector_factories::create_factories_and_collect_column_definitions( - raw_selector::to_selectables(raw_selectors, schema), schema, defs); + raw_selector::to_selectables(raw_selectors, schema), db, schema, defs); auto metadata = collect_metadata(schema, raw_selectors, *factories); if (processes_selection(raw_selectors)) { diff --git a/cql3/selection/selection.hh b/cql3/selection/selection.hh index 493622f2e1..cc3978ae14 100644 --- a/cql3/selection/selection.hh +++ b/cql3/selection/selection.hh @@ -167,7 +167,7 @@ private: static std::vector<::shared_ptr> collect_metadata(schema_ptr schema, const std::vector<::shared_ptr>& raw_selectors, const selector_factories& factories); public: - static ::shared_ptr from_selectors(schema_ptr schema, const std::vector<::shared_ptr>& raw_selectors); + static ::shared_ptr from_selectors(database& db, schema_ptr schema, const std::vector<::shared_ptr>& raw_selectors); virtual std::unique_ptr new_selectors() = 0; diff --git a/cql3/selection/selector.hh b/cql3/selection/selector.hh index a300ddd967..0bc9cd46c4 100644 --- a/cql3/selection/selector.hh +++ b/cql3/selection/selector.hh @@ -87,7 +87,7 @@ public: */ virtual void reset() = 0; - virtual assignment_testable::test_result test_assignment(const sstring& keyspace, ::shared_ptr receiver) override { + virtual assignment_testable::test_result test_assignment(database& db, const sstring& keyspace, ::shared_ptr receiver) override { if (receiver->type == get_type()) { return assignment_testable::test_result::EXACT_MATCH; } else if (receiver->type->is_value_compatible_with(*get_type())) { diff --git a/cql3/selection/selector_factories.cc b/cql3/selection/selector_factories.cc index 775f008b5d..0fabf03f09 100644 --- a/cql3/selection/selector_factories.cc +++ b/cql3/selection/selector_factories.cc @@ -30,7 +30,8 @@ namespace cql3 { namespace selection { -selector_factories::selector_factories(std::vector<::shared_ptr> selectables, schema_ptr schema, +selector_factories::selector_factories(std::vector<::shared_ptr> selectables, + database& db, schema_ptr schema, std::vector& defs) : _contains_write_time_factory(false) , _contains_ttl_factory(false) @@ -39,7 +40,7 @@ selector_factories::selector_factories(std::vector<::shared_ptr> sel _factories.reserve(selectables.size()); for (auto&& selectable : selectables) { - auto factory = selectable->new_selector_factory(schema, defs); + auto factory = selectable->new_selector_factory(db, schema, defs); _contains_write_time_factory |= factory->is_write_time_selector_factory(); _contains_ttl_factory |= factory->is_ttl_selector_factory(); if (factory->is_aggregate_selector_factory()) { diff --git a/cql3/selection/selector_factories.hh b/cql3/selection/selector_factories.hh index 411964667c..76e1cab5b1 100644 --- a/cql3/selection/selector_factories.hh +++ b/cql3/selection/selector_factories.hh @@ -69,12 +69,13 @@ public: */ static ::shared_ptr create_factories_and_collect_column_definitions( std::vector<::shared_ptr> selectables, - schema_ptr schema, + database& db, schema_ptr schema, std::vector& defs) { - return ::make_shared(std::move(selectables), std::move(schema), defs); + return ::make_shared(std::move(selectables), db, std::move(schema), defs); } - selector_factories(std::vector<::shared_ptr> selectables, schema_ptr schema, std::vector& defs); + selector_factories(std::vector<::shared_ptr> selectables, + database& db, schema_ptr schema, std::vector& defs); public: bool uses_function(const sstring& ks_name, const sstring& function_name) const; diff --git a/cql3/selection/writetime_or_ttl.hh b/cql3/selection/writetime_or_ttl.hh index 7ec51e008d..f5556bb749 100644 --- a/cql3/selection/writetime_or_ttl.hh +++ b/cql3/selection/writetime_or_ttl.hh @@ -49,7 +49,7 @@ public: } #endif - virtual shared_ptr new_selector_factory(schema_ptr s, std::vector& defs) override; + virtual shared_ptr new_selector_factory(database& db, schema_ptr s, std::vector& defs) override; class raw : public selectable::raw { shared_ptr _id; diff --git a/cql3/sets.cc b/cql3/sets.cc index 9d7fcc848c..8af5e792b7 100644 --- a/cql3/sets.cc +++ b/cql3/sets.cc @@ -16,8 +16,8 @@ sets::value_spec_of(shared_ptr column) { } shared_ptr -sets::literal::prepare(const sstring& keyspace, shared_ptr receiver) { - validate_assignable_to(keyspace, receiver); +sets::literal::prepare(database& db, const sstring& keyspace, shared_ptr receiver) { + validate_assignable_to(db, keyspace, receiver); // We've parsed empty maps as a set literal to break the ambiguity so // handle that case now @@ -33,7 +33,7 @@ sets::literal::prepare(const sstring& keyspace, shared_ptr bool all_terminal = true; for (shared_ptr rt : _elements) { - auto t = rt->prepare(keyspace, value_spec); + auto t = rt->prepare(db, keyspace, value_spec); if (t->contains_bind_marker()) { throw exceptions::invalid_request_exception(sprint("Invalid set literal for %s: bind variables are not supported inside collection literals", *receiver->name)); @@ -56,7 +56,7 @@ sets::literal::prepare(const sstring& keyspace, shared_ptr } void -sets::literal::validate_assignable_to(const sstring& keyspace, shared_ptr receiver) { +sets::literal::validate_assignable_to(database& db, const sstring& keyspace, shared_ptr receiver) { if (!dynamic_pointer_cast(receiver->type)) { // We've parsed empty maps as a set literal to break the ambiguity so // handle that case now @@ -69,14 +69,14 @@ sets::literal::validate_assignable_to(const sstring& keyspace, shared_ptr rt : _elements) { - if (!is_assignable(rt->test_assignment(keyspace, value_spec))) { + if (!is_assignable(rt->test_assignment(db, keyspace, value_spec))) { throw exceptions::invalid_request_exception(sprint("Invalid set literal for %s: value %s is not of type %s", *receiver->name, *rt, *value_spec->type->as_cql3_type())); } } } assignment_testable::test_result -sets::literal::test_assignment(const sstring& keyspace, shared_ptr receiver) { +sets::literal::test_assignment(database& db, const sstring& keyspace, shared_ptr receiver) { if (!dynamic_pointer_cast(receiver->type)) { // We've parsed empty maps as a set literal to break the ambiguity so handle that case now if (dynamic_pointer_cast(receiver->type) && _elements.empty()) { @@ -94,7 +94,7 @@ sets::literal::test_assignment(const sstring& keyspace, shared_ptr> to_test(_elements.begin(), _elements.end()); - return assignment_testable::test_all(keyspace, value_spec, to_test); + return assignment_testable::test_all(db, keyspace, value_spec, to_test); } sstring diff --git a/cql3/sets.hh b/cql3/sets.hh index 69f7dbfd5e..a3c2d32795 100644 --- a/cql3/sets.hh +++ b/cql3/sets.hh @@ -71,10 +71,10 @@ public: explicit literal(std::vector> elements) : _elements(std::move(elements)) { } - shared_ptr prepare(const sstring& keyspace, shared_ptr receiver); - void validate_assignable_to(const sstring& keyspace, shared_ptr receiver); + shared_ptr prepare(database& db, const sstring& keyspace, shared_ptr receiver); + void validate_assignable_to(database& db, const sstring& keyspace, shared_ptr receiver); assignment_testable::test_result - test_assignment(const sstring& keyspace, shared_ptr receiver); + test_assignment(database& db, const sstring& keyspace, shared_ptr receiver); virtual sstring to_string() const override; }; diff --git a/cql3/single_column_relation.cc b/cql3/single_column_relation.cc index f69f57aca1..fdd56da654 100644 --- a/cql3/single_column_relation.cc +++ b/cql3/single_column_relation.cc @@ -33,20 +33,21 @@ namespace cql3 { ::shared_ptr single_column_relation::to_term(const std::vector<::shared_ptr>& receivers, ::shared_ptr raw, + database& db, const sstring& keyspace, ::shared_ptr bound_names) { // TODO: optimize vector away, accept single column_specification assert(receivers.size() == 1); - auto term = raw->prepare(keyspace, receivers[0]); + auto term = raw->prepare(db, keyspace, receivers[0]); term->collect_marker_specification(bound_names); return term; } ::shared_ptr -single_column_relation::new_EQ_restriction(schema_ptr schema, ::shared_ptr bound_names) { +single_column_relation::new_EQ_restriction(database& db, schema_ptr schema, ::shared_ptr bound_names) { const column_definition& column_def = to_column_definition(schema, _entity); if (!_map_key) { - auto term = to_term(to_receivers(schema, column_def), _value, schema->ks_name, bound_names); + auto term = to_term(to_receivers(schema, column_def), _value, db, schema->ks_name, bound_names); return ::make_shared(column_def, std::move(term)); } fail(unimplemented::cause::COLLECTIONS); @@ -59,10 +60,10 @@ single_column_relation::new_EQ_restriction(schema_ptr schema, ::shared_ptr -single_column_relation::new_IN_restriction(schema_ptr schema, ::shared_ptr bound_names) { +single_column_relation::new_IN_restriction(database& db, schema_ptr schema, ::shared_ptr bound_names) { const column_definition& column_def = to_column_definition(schema, _entity); auto receivers = to_receivers(schema, column_def); - auto terms = to_terms(receivers, _in_values, schema->ks_name, bound_names); + auto terms = to_terms(receivers, _in_values, db, schema->ks_name, bound_names); if (terms.empty()) { fail(unimplemented::cause::COLLECTIONS); #if 0 diff --git a/cql3/single_column_relation.hh b/cql3/single_column_relation.hh index 8b9c69dbdf..d4b2db42b5 100644 --- a/cql3/single_column_relation.hh +++ b/cql3/single_column_relation.hh @@ -95,7 +95,7 @@ public: } protected: virtual ::shared_ptr to_term(const std::vector<::shared_ptr>& receivers, - ::shared_ptr raw, const sstring& keyspace, + ::shared_ptr raw, database& db, const sstring& keyspace, ::shared_ptr bound_names) override; #if 0 @@ -124,22 +124,22 @@ protected: } protected: - virtual ::shared_ptr new_EQ_restriction(schema_ptr schema, + virtual ::shared_ptr new_EQ_restriction(database& db, schema_ptr schema, ::shared_ptr bound_names); - virtual ::shared_ptr new_IN_restriction(schema_ptr schema, + virtual ::shared_ptr new_IN_restriction(database& db, schema_ptr schema, ::shared_ptr bound_names) override; - virtual ::shared_ptr new_slice_restriction(schema_ptr schema, + virtual ::shared_ptr new_slice_restriction(database& db, schema_ptr schema, ::shared_ptr bound_names, statements::bound bound, bool inclusive) override { auto&& column_def = to_column_definition(schema, _entity); - auto term = to_term(to_receivers(schema, column_def), _value, schema->ks_name, std::move(bound_names)); + auto term = to_term(to_receivers(schema, column_def), _value, db, schema->ks_name, std::move(bound_names)); return ::make_shared(column_def, bound, inclusive, std::move(term)); } - virtual shared_ptr new_contains_restriction(schema_ptr schema, + virtual shared_ptr new_contains_restriction(database& db, schema_ptr schema, ::shared_ptr bound_names, bool is_key) override { throw std::runtime_error("not implemented"); diff --git a/cql3/statements/batch_statement.hh b/cql3/statements/batch_statement.hh index 0e9caf8158..3a53901fbb 100644 --- a/cql3/statements/batch_statement.hh +++ b/cql3/statements/batch_statement.hh @@ -350,7 +350,7 @@ public: statements.push_back(parsed->prepare(db, bound_names)); } - auto&& prep_attrs = _attrs->prepare("[batch]", "[batch]"); + auto&& prep_attrs = _attrs->prepare(db, "[batch]", "[batch]"); prep_attrs->collect_marker_specification(bound_names); batch_statement batch_statement_(bound_names->size(), _type, std::move(statements), std::move(prep_attrs)); diff --git a/cql3/statements/delete_statement.cc b/cql3/statements/delete_statement.cc index df9b024e5d..a359c7f492 100644 --- a/cql3/statements/delete_statement.cc +++ b/cql3/statements/delete_statement.cc @@ -52,7 +52,7 @@ void delete_statement::add_update_for_key(mutation& m, const exploded_clustering } ::shared_ptr -delete_statement::parsed::prepare_internal(schema_ptr schema, ::shared_ptr bound_names, +delete_statement::parsed::prepare_internal(database& db, schema_ptr schema, ::shared_ptr bound_names, std::unique_ptr attrs) { auto stmt = ::make_shared(statement_type::DELETE, bound_names->size(), schema, std::move(attrs)); @@ -70,12 +70,12 @@ delete_statement::parsed::prepare_internal(schema_ptr schema, ::shared_ptrname_as_text())); } - auto&& op = deletion->prepare(schema->ks_name, *def); + auto&& op = deletion->prepare(db, schema->ks_name, *def); op->collect_marker_specification(bound_names); stmt->add_operation(op); } - stmt->process_where_clause(_where_clause, std::move(bound_names)); + stmt->process_where_clause(db, _where_clause, std::move(bound_names)); return stmt; } diff --git a/cql3/statements/delete_statement.hh b/cql3/statements/delete_statement.hh index ad05133af6..2558d67711 100644 --- a/cql3/statements/delete_statement.hh +++ b/cql3/statements/delete_statement.hh @@ -99,7 +99,7 @@ public: , _where_clause(std::move(where_clause)) { } protected: - virtual ::shared_ptr prepare_internal(schema_ptr schema, + virtual ::shared_ptr prepare_internal(database& db, schema_ptr schema, ::shared_ptr bound_names, std::unique_ptr attrs); }; }; diff --git a/cql3/statements/modification_statement.cc b/cql3/statements/modification_statement.cc index 0f21325370..6f0f288358 100644 --- a/cql3/statements/modification_statement.cc +++ b/cql3/statements/modification_statement.cc @@ -411,7 +411,7 @@ modification_statement::add_key_value(const column_definition& def, ::shared_ptr } void -modification_statement::process_where_clause(std::vector where_clause, ::shared_ptr names) { +modification_statement::process_where_clause(database& db, std::vector where_clause, ::shared_ptr names) { for (auto&& relation : where_clause) { if (relation->is_multi_column()) { throw exceptions::invalid_request_exception(sprint("Multi-column relations cannot be used in WHERE clauses for UPDATE and DELETE statements: %s", relation->to_string())); @@ -430,7 +430,7 @@ modification_statement::process_where_clause(std::vector where_cla if (def->is_primary_key()) { if (rel->is_EQ() || (def->is_partition_key() && rel->is_IN())) { - add_key_values(*def, rel->to_restriction(s, std::move(names))); + add_key_values(*def, rel->to_restriction(db, s, std::move(names))); } else { throw exceptions::invalid_request_exception(sprint("Invalid operator %s for PRIMARY KEY part %s", rel->get_operator(), def->name_as_text())); } @@ -451,10 +451,10 @@ modification_statement::parsed::prepare(database& db) { modification_statement::parsed::prepare(database& db, ::shared_ptr bound_names) { schema_ptr schema = validation::validate_column_family(db, keyspace(), column_family()); - auto prepared_attributes = _attrs->prepare(keyspace(), column_family()); + auto prepared_attributes = _attrs->prepare(db, keyspace(), column_family()); prepared_attributes->collect_marker_specification(bound_names); - ::shared_ptr stmt = prepare_internal(schema, bound_names, std::move(prepared_attributes)); + ::shared_ptr stmt = prepare_internal(db, schema, bound_names, std::move(prepared_attributes)); if (_if_not_exists || _if_exists || !_conditions.empty()) { if (stmt->is_counter()) { @@ -482,7 +482,7 @@ modification_statement::parsed::prepare(database& db, ::shared_ptrprepare(keyspace(), *def); + auto condition = entry.second->prepare(db, keyspace(), *def); condition->collect_marker_specificaton(bound_names); if (def->is_primary_key()) { diff --git a/cql3/statements/modification_statement.hh b/cql3/statements/modification_statement.hh index 3a4411c88c..c2bc136311 100644 --- a/cql3/statements/modification_statement.hh +++ b/cql3/statements/modification_statement.hh @@ -223,7 +223,7 @@ private: public: void add_key_value(const column_definition& def, ::shared_ptr value); - void process_where_clause(std::vector where_clause, ::shared_ptr names); + void process_where_clause(database& db, std::vector where_clause, ::shared_ptr names); std::vector build_partition_keys(const query_options& options); private: @@ -437,7 +437,7 @@ public: virtual ::shared_ptr prepare(database& db) override; ::shared_ptr prepare(database& db, ::shared_ptr bound_names);; protected: - virtual ::shared_ptr prepare_internal(schema_ptr schema, + virtual ::shared_ptr prepare_internal(database& db, schema_ptr schema, ::shared_ptr bound_names, std::unique_ptr attrs) = 0; }; }; diff --git a/cql3/statements/select_statement.cc b/cql3/statements/select_statement.cc index 15e3def4ba..774db72173 100644 --- a/cql3/statements/select_statement.cc +++ b/cql3/statements/select_statement.cc @@ -331,9 +331,9 @@ select_statement::raw_statement::prepare(database& db) { auto selection = _select_clause.empty() ? selection::selection::wildcard(schema) - : selection::selection::from_selectors(schema, _select_clause); + : selection::selection::from_selectors(db, schema, _select_clause); - auto restrictions = prepare_restrictions(schema, bound_names, selection); + auto restrictions = prepare_restrictions(db, schema, bound_names, selection); if (_parameters->is_distinct()) { validate_distinct_selection(schema, selection, restrictions); @@ -361,18 +361,18 @@ select_statement::raw_statement::prepare(database& db) { std::move(restrictions), is_reversed_, std::move(ordering_comparator), - prepare_limit(bound_names)); + prepare_limit(db, bound_names)); return ::make_shared(std::move(stmt), std::move(*bound_names)); } ::shared_ptr -select_statement::raw_statement::prepare_restrictions(schema_ptr schema, +select_statement::raw_statement::prepare_restrictions(database& db, schema_ptr schema, ::shared_ptr bound_names, ::shared_ptr selection) { try { - return ::make_shared(schema, std::move(_where_clause), bound_names, + return ::make_shared(db, schema, std::move(_where_clause), bound_names, selection->contains_only_static_columns(), selection->contains_a_collection()); } catch (const exceptions::unrecognized_entity_exception& e) { if (contains_alias(e.entity)) { @@ -384,12 +384,12 @@ select_statement::raw_statement::prepare_restrictions(schema_ptr schema, /** Returns a ::shared_ptr for the limit or null if no limit is set */ ::shared_ptr -select_statement::raw_statement::prepare_limit(::shared_ptr bound_names) { +select_statement::raw_statement::prepare_limit(database& db, ::shared_ptr bound_names) { if (!_limit) { return {}; } - auto prep_limit = _limit->prepare(keyspace(), limit_receiver()); + auto prep_limit = _limit->prepare(db, keyspace(), limit_receiver()); prep_limit->collect_marker_specification(bound_names); return prep_limit; } diff --git a/cql3/statements/select_statement.hh b/cql3/statements/select_statement.hh index 31dc3413c1..8eabaf0f04 100644 --- a/cql3/statements/select_statement.hh +++ b/cql3/statements/select_statement.hh @@ -456,12 +456,13 @@ public: virtual ::shared_ptr prepare(database& db) override; private: ::shared_ptr prepare_restrictions( + database& db, schema_ptr schema, ::shared_ptr bound_names, ::shared_ptr selection); /** Returns a ::shared_ptr for the limit or null if no limit is set */ - ::shared_ptr prepare_limit(::shared_ptr bound_names); + ::shared_ptr prepare_limit(database& db, ::shared_ptr bound_names); static void verify_ordering_is_allowed(::shared_ptr restrictions); diff --git a/cql3/statements/update_statement.cc b/cql3/statements/update_statement.cc index 2da26c4de3..e71a9ef52e 100644 --- a/cql3/statements/update_statement.cc +++ b/cql3/statements/update_statement.cc @@ -88,7 +88,7 @@ void update_statement::add_update_for_key(mutation& m, const exploded_clustering } ::shared_ptr -update_statement::parsed_insert::prepare_internal(schema_ptr schema, +update_statement::parsed_insert::prepare_internal(database& db, schema_ptr schema, ::shared_ptr bound_names, std::unique_ptr attrs) { auto stmt = ::make_shared(statement_type::INSERT, bound_names->size(), schema, std::move(attrs)); @@ -123,11 +123,11 @@ update_statement::parsed_insert::prepare_internal(schema_ptr schema, auto&& value = _column_values[i]; if (def->is_primary_key()) { - auto t = value->prepare(keyspace(), def->column_specification); + auto t = value->prepare(db, keyspace(), def->column_specification); t->collect_marker_specification(bound_names); stmt->add_key_value(*def, std::move(t)); } else { - auto operation = operation::set_value(value).prepare(keyspace(), *def); + auto operation = operation::set_value(value).prepare(db, keyspace(), *def); operation->collect_marker_specification(bound_names); stmt->add_operation(std::move(operation)); }; @@ -136,7 +136,7 @@ update_statement::parsed_insert::prepare_internal(schema_ptr schema, } ::shared_ptr -update_statement::parsed_update::prepare_internal(schema_ptr schema, +update_statement::parsed_update::prepare_internal(database& db, schema_ptr schema, ::shared_ptr bound_names, std::unique_ptr attrs) { auto stmt = ::make_shared(statement_type::UPDATE, bound_names->size(), schema, std::move(attrs)); @@ -148,7 +148,7 @@ update_statement::parsed_update::prepare_internal(schema_ptr schema, throw exceptions::invalid_request_exception(sprint("Unknown identifier %s", *entry.first)); } - auto operation = entry.second->prepare(keyspace(), *def); + auto operation = entry.second->prepare(db, keyspace(), *def); operation->collect_marker_specification(bound_names); if (def->is_primary_key()) { @@ -157,7 +157,7 @@ update_statement::parsed_update::prepare_internal(schema_ptr schema, stmt->add_operation(std::move(operation)); } - stmt->process_where_clause(_where_clause, bound_names); + stmt->process_where_clause(db, _where_clause, bound_names); return stmt; } diff --git a/cql3/statements/update_statement.hh b/cql3/statements/update_statement.hh index d6608cb632..fb03aeecb0 100644 --- a/cql3/statements/update_statement.hh +++ b/cql3/statements/update_statement.hh @@ -99,7 +99,7 @@ public: , _column_values{std::move(column_values)} { } - virtual ::shared_ptr prepare_internal(schema_ptr schema, + virtual ::shared_ptr prepare_internal(database& db, schema_ptr schema, ::shared_ptr bound_names, std::unique_ptr attrs) override; }; @@ -129,7 +129,7 @@ public: , _where_clause(std::move(where_clause)) { } protected: - virtual ::shared_ptr prepare_internal(schema_ptr schema, + virtual ::shared_ptr prepare_internal(database& db, schema_ptr schema, ::shared_ptr bound_names, std::unique_ptr attrs); }; }; diff --git a/cql3/term.hh b/cql3/term.hh index b8f94c28e6..44691bab2a 100644 --- a/cql3/term.hh +++ b/cql3/term.hh @@ -115,7 +115,7 @@ public: * case this RawTerm describe a list index or a map key, etc... * @return the prepared term. */ - virtual ::shared_ptr prepare(const sstring& keyspace, ::shared_ptr receiver) = 0; + virtual ::shared_ptr prepare(database& db, const sstring& keyspace, ::shared_ptr receiver) = 0; virtual sstring to_string() const = 0; @@ -131,7 +131,7 @@ public: class multi_column_raw : public virtual raw { public: - virtual ::shared_ptr prepare(const sstring& keyspace, const std::vector>& receiver) = 0; + virtual ::shared_ptr prepare(database& db, const sstring& keyspace, const std::vector>& receiver) = 0; }; }; diff --git a/cql3/tuples.hh b/cql3/tuples.hh index 5e3ebd6339..21fec34877 100644 --- a/cql3/tuples.hh +++ b/cql3/tuples.hh @@ -47,12 +47,12 @@ public: literal(std::vector> elements) : _elements(std::move(elements)) { } - virtual shared_ptr prepare(const sstring& keyspace, shared_ptr receiver) override { - validate_assignable_to(keyspace, receiver); + virtual shared_ptr prepare(database& db, const sstring& keyspace, shared_ptr receiver) override { + validate_assignable_to(db, keyspace, receiver); std::vector> values; bool all_terminal = true; for (size_t i = 0; i < _elements.size(); ++i) { - auto&& value = _elements[i]->prepare(keyspace, component_spec_of(receiver, i)); + auto&& value = _elements[i]->prepare(db, keyspace, component_spec_of(receiver, i)); if (dynamic_pointer_cast(value)) { all_terminal = false; } @@ -66,7 +66,7 @@ public: } } - virtual shared_ptr prepare(const sstring& keyspace, const std::vector>& receivers) override { + virtual shared_ptr prepare(database& db, const sstring& keyspace, const std::vector>& receivers) override { if (_elements.size() != receivers.size()) { throw exceptions::invalid_request_exception(sprint("Expected %d elements in value tuple, but got %d: %s", receivers.size(), _elements.size(), *this)); } @@ -75,7 +75,7 @@ public: std::vector types; bool all_terminal = true; for (size_t i = 0; i < _elements.size(); ++i) { - auto&& t = _elements[i]->prepare(keyspace, receivers[i]); + auto&& t = _elements[i]->prepare(db, keyspace, receivers[i]); if (dynamic_pointer_cast(t)) { all_terminal = false; } @@ -91,7 +91,7 @@ public: } private: - void validate_assignable_to(const sstring& keyspace, shared_ptr receiver) { + void validate_assignable_to(database& db, const sstring& keyspace, shared_ptr receiver) { auto tt = dynamic_pointer_cast(receiver->type); if (!tt) { throw exceptions::invalid_request_exception(sprint("Invalid tuple type literal for %s of type %s", receiver->name, receiver->type->as_cql3_type())); @@ -104,15 +104,15 @@ public: auto&& value = _elements[i]; auto&& spec = component_spec_of(receiver, i); - if (!assignment_testable::is_assignable(value->test_assignment(keyspace, spec))) { + if (!assignment_testable::is_assignable(value->test_assignment(db, keyspace, spec))) { throw exceptions::invalid_request_exception(sprint("Invalid tuple literal for %s: component %d is not of type %s", receiver->name, i, spec->type->as_cql3_type())); } } } public: - virtual assignment_testable::test_result test_assignment(const sstring& keyspace, shared_ptr receiver) override { + virtual assignment_testable::test_result test_assignment(database& db, const sstring& keyspace, shared_ptr receiver) override { try { - validate_assignable_to(keyspace, receiver); + validate_assignable_to(db, keyspace, receiver); return assignment_testable::test_result::WEAKLY_ASSIGNABLE; } catch (exceptions::invalid_request_exception e) { return assignment_testable::test_result::NOT_ASSIGNABLE; diff --git a/cql3/type_cast.hh b/cql3/type_cast.hh index 21a76a3481..43be9f3ca7 100644 --- a/cql3/type_cast.hh +++ b/cql3/type_cast.hh @@ -35,24 +35,24 @@ public: : _type(std::move(type)), _term(std::move(term)) { } - virtual shared_ptr prepare(const sstring& keyspace, shared_ptr receiver) override { - if (!is_assignable(_term->test_assignment(keyspace, casted_spec_of(keyspace, receiver)))) { + virtual shared_ptr prepare(database& db, const sstring& keyspace, shared_ptr receiver) override { + if (!is_assignable(_term->test_assignment(db, keyspace, casted_spec_of(db, keyspace, receiver)))) { throw exceptions::invalid_request_exception(sprint("Cannot cast value %s to type %s", _term, _type)); } - if (!is_assignable(test_assignment(keyspace, receiver))) { + if (!is_assignable(test_assignment(db, keyspace, receiver))) { throw exceptions::invalid_request_exception(sprint("Cannot assign value %s to %s of type %s", *this, receiver->name, receiver->type->as_cql3_type())); } - return _term->prepare(keyspace, receiver); + return _term->prepare(db, keyspace, receiver); } private: - shared_ptr casted_spec_of(const sstring& keyspace, shared_ptr receiver) { + shared_ptr casted_spec_of(database& db, const sstring& keyspace, shared_ptr receiver) { return make_shared(receiver->ks_name, receiver->cf_name, - make_shared(to_string(), true), _type->prepare(keyspace)->get_type()); + make_shared(to_string(), true), _type->prepare(db, keyspace)->get_type()); } public: - virtual assignment_testable::test_result test_assignment(const sstring& keyspace, shared_ptr receiver) override { + virtual assignment_testable::test_result test_assignment(database& db, const sstring& keyspace, shared_ptr receiver) override { try { - auto&& casted_type = _type->prepare(keyspace)->get_type(); + auto&& casted_type = _type->prepare(db, keyspace)->get_type(); if (receiver->type->equals(casted_type)) { return assignment_testable::test_result::EXACT_MATCH; } else if (receiver->type->is_value_compatible_with(*casted_type)) { diff --git a/cql3/user_types.hh b/cql3/user_types.hh new file mode 100644 index 0000000000..c51351f926 --- /dev/null +++ b/cql3/user_types.hh @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* + * Modified by Cloudius Systems + * + * Copyright 2015 Cloudius Systems + */ + +#pragma once + +#include "column_specification.hh" +#include "term.hh" +#include "column_identifier.hh" +#include "constants.hh" +#include "to_string.hh" +#include +#include +#include + +namespace cql3 { + +/** + * Static helper methods and classes for user types. + */ +class user_types { + user_types() = delete; +public: + static shared_ptr field_spec_of(shared_ptr column, size_t field) { + auto&& ut = static_pointer_cast(column->type); + auto&& name = ut->field_name(field); + auto&& sname = sstring(reinterpret_cast(name.data()), name.size()); + return make_shared( + column->ks_name, + column->cf_name, + make_shared(column->name->to_string() + "." + sname, true), + ut->field_type(field)); + } + + class literal : public term::raw { + public: + using elements_map_type = std::unordered_map>; + elements_map_type _entries; + + literal(elements_map_type entries) + : _entries(std::move(entries)) { + } + + virtual shared_ptr prepare(database& db, const sstring& keyspace, shared_ptr receiver) override { + validate_assignable_to(db, keyspace, receiver); + auto&& ut = static_pointer_cast(receiver->type); + bool all_terminal = true; + std::vector> values; + values.reserve(_entries.size()); + size_t found_values = 0; + for (size_t i = 0; i < ut->size(); ++i) { + auto&& field = column_identifier(to_bytes(ut->field_name(i)), utf8_type); + auto iraw = _entries.find(field); + shared_ptr raw; + if (iraw == _entries.end()) { + raw = cql3::constants::NULL_LITERAL; + } else { + raw = iraw->second; + ++found_values; + } + auto&& value = raw->prepare(db, keyspace, field_spec_of(receiver, i)); + + if (dynamic_cast(value.get())) { + all_terminal = false; + } + + values.push_back(std::move(value)); + } + if (found_values != _entries.size()) { + // We had some field that are not part of the type + for (auto&& id_val : _entries) { + auto&& id = id_val.first; + if (!boost::range::count(ut->field_names(), id.bytes_)) { + throw exceptions::invalid_request_exception(sprint("Unknown field '%s' in value of user defined type %s", id, ut->get_name_as_string())); + } + } + } + + delayed_value value(ut, values); + if (all_terminal) { + return value.bind(query_options::DEFAULT); + } else { + return make_shared(std::move(value)); + } + } + private: + void validate_assignable_to(database& db, const sstring& keyspace, shared_ptr receiver) { + auto&& ut = dynamic_pointer_cast(receiver->type); + if (!ut) { + throw exceptions::invalid_request_exception(sprint("Invalid user type literal for %s of type %s", receiver->name, receiver->type->as_cql3_type())); + } + + for (size_t i = 0; i < ut->size(); i++) { + column_identifier field(to_bytes(ut->field_name(i)), utf8_type); + if (_entries.count(field) == 0) { + continue; + } + shared_ptr value = _entries[field]; + auto&& field_spec = field_spec_of(receiver, i); + if (!assignment_testable::is_assignable(value->test_assignment(db, keyspace, field_spec))) { + throw exceptions::invalid_request_exception(sprint("Invalid user type literal for %s: field %s is not of type %s", receiver->name, field, field_spec->type->as_cql3_type())); + } + } + } + public: + virtual assignment_testable::test_result test_assignment(database& db, const sstring& keyspace, shared_ptr receiver) override { + try { + validate_assignable_to(db, keyspace, receiver); + return assignment_testable::test_result::WEAKLY_ASSIGNABLE; + } catch (exceptions::invalid_request_exception& e) { + return assignment_testable::test_result::NOT_ASSIGNABLE; + } + } + + virtual sstring assignment_testable_source_context() const override { + return to_string(); + } + + virtual sstring to_string() const override { + auto kv_to_str = [] (auto&& kv) { return sprint("%s:%s", kv.first, kv.second); }; + return sprint("{%s}", ::join(", ", _entries | boost::adaptors::transformed(kv_to_str))); + } + }; + + // Same purpose than Lists.DelayedValue, except we do handle bind marker in that case + class delayed_value : public non_terminal { + user_type _type; + std::vector> _values; + public: + delayed_value(user_type type, std::vector> values) + : _type(std::move(type)), _values(std::move(values)) { + } + virtual bool uses_function(const sstring& ks_name, const sstring& function_name) const override { + return boost::algorithm::any_of(_values, + std::bind(&term::uses_function, std::placeholders::_1, std::cref(ks_name), std::cref(function_name))); + } + virtual bool contains_bind_marker() const override { + return boost::algorithm::any_of(_values, std::mem_fn(&term::contains_bind_marker)); + } + + virtual void collect_marker_specification(shared_ptr bound_names) { + for (auto&& v : _values) { + v->collect_marker_specification(bound_names); + } + } + private: + std::vector bind_internal(const query_options& options) { + auto sf = options.get_serialization_format(); + std::vector buffers; + for (size_t i = 0; i < _type->size(); ++i) { + buffers.push_back(_values[i]->bind_and_get(options)); + // Inside UDT values, we must force the serialization of collections to v3 whatever protocol + // version is in use since we're going to store directly that serialized value. + if (sf != serialization_format::use_32_bit() && _type->field_type(i)->is_collection() && buffers.back()) { + auto&& ctype = static_pointer_cast(_type->field_type(i)); + buffers.back() = ctype->reserialize(sf, serialization_format::use_32_bit(), bytes_view(*buffers.back())); + } + } + return buffers; + } + public: + virtual shared_ptr bind(const query_options& options) override { + return ::make_shared(bind_and_get(options)); + } + + virtual bytes_opt bind_and_get(const query_options& options) override { + return user_type_impl::build_value(bind_internal(options)); + } + }; +}; + +} diff --git a/tests/urchin/cql_query_test.cc b/tests/urchin/cql_query_test.cc index 88013a25e5..6904bcfb0e 100644 --- a/tests/urchin/cql_query_test.cc +++ b/tests/urchin/cql_query_test.cc @@ -771,3 +771,36 @@ SEASTAR_TEST_CASE(test_tuples) { }); }); } + +SEASTAR_TEST_CASE(test_user_type) { + auto make_user_type = [] { + return user_type_impl::get_instance("ks", to_bytes("ut1"), + {to_bytes("my_int"), to_bytes("my_bigint"), to_bytes("my_text")}, + {int32_type, long_type, utf8_type}); + }; + return do_with_cql_env([make_user_type] (cql_test_env& e) { + return e.create_table([make_user_type] (auto ks_name) { + // CQL: "create table cf (id int primary key, t tuple); + return schema({}, ks_name, "cf", + {{"id", int32_type}}, {}, {{"t", make_user_type()}}, {}, utf8_type); + }).then([&e] { + return e.execute_cql("insert into cf (id, t) values (1, (1001, 2001, 'abc1'));").discard_result(); + }).then([&e] { + return e.execute_cql("select t.my_int, t.my_bigint, t.my_text from cf where id = 1;"); + }).then([&e] (shared_ptr msg) { + assert_that(msg).is_rows() + .with_rows({ + {int32_type->decompose(int32_t(1001)), long_type->decompose(int64_t(2001)), utf8_type->decompose(sstring("abc1"))}, + }); + }).then([&e] { + return e.execute_cql("update cf set t = { my_int: 1002, my_bigint: 2002, my_text: 'abc2' } where id = 1;").discard_result(); + }).then([&e] { + return e.execute_cql("select t.my_int, t.my_bigint, t.my_text from cf where id = 1;"); + }).then([&e] (shared_ptr msg) { + assert_that(msg).is_rows() + .with_rows({ + {int32_type->decompose(int32_t(1002)), long_type->decompose(int64_t(2002)), utf8_type->decompose(sstring("abc2"))}, + }); + }); + }); +} diff --git a/types.cc b/types.cc index 2c91e875b6..b8f84867e1 100644 --- a/types.cc +++ b/types.cc @@ -1484,13 +1484,17 @@ list_type_impl::cql3_type_name() const { return sprint("list<%s>", _elements->as_cql3_type()); } -tuple_type_impl::tuple_type_impl(std::vector types) - : abstract_type(make_name(types)), _types(std::move(types)) { +tuple_type_impl::tuple_type_impl(sstring name, std::vector types) + : abstract_type(std::move(name)), _types(std::move(types)) { for (auto& t : _types) { t = t->freeze(); } } +tuple_type_impl::tuple_type_impl(std::vector types) + : tuple_type_impl(make_name(types), std::move(types)) { +} + shared_ptr tuple_type_impl::get_instance(std::vector types) { return ::make_shared(std::move(types)); @@ -1631,6 +1635,29 @@ tuple_type_impl::make_name(const std::vector& types) { return sprint("tuple<%s>", ::join(", ", types | boost::adaptors::transformed(std::mem_fn(&abstract_type::name)))); } +sstring +user_type_impl::get_name_as_string() const { + return boost::any_cast(utf8_type->compose(_name)); +} + +shared_ptr +user_type_impl::as_cql3_type() { + throw "not yet"; +} + +sstring +user_type_impl::make_name(sstring keyspace, bytes name, std::vector field_names, std::vector field_types) { + std::ostringstream os; + os << "(" << keyspace << "," << to_hex(name); + for (size_t i = 0; i < field_names.size(); ++i) { + os << ","; + os << to_hex(field_names[i]) << ":"; + os << field_types[i]->name(); // FIXME: ignore frozen<> + } + os << ")"; + return os.str(); +} + thread_local const shared_ptr int32_type(make_shared()); thread_local const shared_ptr long_type(make_shared()); thread_local const shared_ptr ascii_type(make_shared("ascii", [] { return cql3::cql3_type::ascii; })); diff --git a/types.hh b/types.hh index 6cb6a3a90f..6cacd2f517 100644 --- a/types.hh +++ b/types.hh @@ -839,6 +839,7 @@ protected: static boost::iterator_range make_range(bytes_view v) { return { tuple_deserializing_iterator::start(v), tuple_deserializing_iterator::finish(v) }; } + tuple_type_impl(sstring name, std::vector types); public: using native_type = std::vector; tuple_type_impl(std::vector types); @@ -889,3 +890,31 @@ private: // FIXME: conflicts with another tuple_type using db_tuple_type = shared_ptr; + +class user_type_impl : public tuple_type_impl { +public: + const sstring _keyspace; + const bytes _name; +private: + std::vector _field_names; +public: + user_type_impl(sstring keyspace, bytes name, std::vector field_names, std::vector field_types) + : tuple_type_impl(make_name(keyspace, name, field_names, field_types), field_types) + , _keyspace(keyspace) + , _name(name) + , _field_names(field_names) { + } + static shared_ptr get_instance(sstring keyspace, bytes name, std::vector field_names, std::vector field_types) { + return ::make_shared(std::move(keyspace), std::move(name), std::move(field_names), std::move(field_types)); + } + data_type field_type(size_t i) const { return type(i); } + const std::vector& field_types() const { return _types; } + bytes_view field_name(size_t i) const { return _field_names[i]; } + const std::vector& field_names() const { return _field_names; } + sstring get_name_as_string() const; + virtual shared_ptr as_cql3_type() override; +private: + static sstring make_name(sstring keyspace, bytes name, std::vector field_names, std::vector field_types); +}; + +using user_type = shared_ptr;