cql3: convert user_types.hh to C++

This commit is contained in:
Avi Kivity
2015-04-20 15:55:26 +03:00
parent 55ec6bb923
commit c01515d291
2 changed files with 191 additions and 201 deletions

View File

@@ -1,201 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.cassandra.cql3;
import java.nio.ByteBuffer;
import java.util.*;
import org.apache.cassandra.db.marshal.CollectionType;
import org.apache.cassandra.db.marshal.UserType;
import org.apache.cassandra.db.marshal.UTF8Type;
import org.apache.cassandra.exceptions.InvalidRequestException;
import org.apache.cassandra.transport.Server;
/**
* Static helper methods and classes for user types.
*/
public abstract class UserTypes
{
private UserTypes() {}
public static ColumnSpecification fieldSpecOf(ColumnSpecification column, int field)
{
UserType ut = (UserType)column.type;
return new ColumnSpecification(column.ksName,
column.cfName,
new ColumnIdentifier(column.name + "." + UTF8Type.instance.compose(ut.fieldName(field)), true),
ut.fieldType(field));
}
public static class Literal implements Term.Raw
{
public final Map<ColumnIdentifier, Term.Raw> entries;
public Literal(Map<ColumnIdentifier, Term.Raw> entries)
{
this.entries = entries;
}
public Term prepare(String keyspace, ColumnSpecification receiver) throws InvalidRequestException
{
validateAssignableTo(keyspace, receiver);
UserType ut = (UserType)receiver.type;
boolean allTerminal = true;
List<Term> values = new ArrayList<>(entries.size());
int foundValues = 0;
for (int i = 0; i < ut.size(); i++)
{
ColumnIdentifier field = new ColumnIdentifier(ut.fieldName(i), UTF8Type.instance);
Term.Raw raw = entries.get(field);
if (raw == null)
raw = Constants.NULL_LITERAL;
else
++foundValues;
Term value = raw.prepare(keyspace, fieldSpecOf(receiver, i));
if (value instanceof Term.NonTerminal)
allTerminal = false;
values.add(value);
}
if (foundValues != entries.size())
{
// We had some field that are not part of the type
for (ColumnIdentifier id : entries.keySet())
if (!ut.fieldNames().contains(id.bytes))
throw new InvalidRequestException(String.format("Unknown field '%s' in value of user defined type %s", id, ut.getNameAsString()));
}
DelayedValue value = new DelayedValue(((UserType)receiver.type), values);
return allTerminal ? value.bind(QueryOptions.DEFAULT) : value;
}
private void validateAssignableTo(String keyspace, ColumnSpecification receiver) throws InvalidRequestException
{
if (!(receiver.type instanceof UserType))
throw new InvalidRequestException(String.format("Invalid user type literal for %s of type %s", receiver, receiver.type.asCQL3Type()));
UserType ut = (UserType)receiver.type;
for (int i = 0; i < ut.size(); i++)
{
ColumnIdentifier field = new ColumnIdentifier(ut.fieldName(i), UTF8Type.instance);
Term.Raw value = entries.get(field);
if (value == null)
continue;
ColumnSpecification fieldSpec = fieldSpecOf(receiver, i);
if (!value.testAssignment(keyspace, fieldSpec).isAssignable())
throw new InvalidRequestException(String.format("Invalid user type literal for %s: field %s is not of type %s", receiver, field, fieldSpec.type.asCQL3Type()));
}
}
public AssignmentTestable.TestResult testAssignment(String keyspace, ColumnSpecification receiver)
{
try
{
validateAssignableTo(keyspace, receiver);
return AssignmentTestable.TestResult.WEAKLY_ASSIGNABLE;
}
catch (InvalidRequestException e)
{
return AssignmentTestable.TestResult.NOT_ASSIGNABLE;
}
}
@Override
public String toString()
{
StringBuilder sb = new StringBuilder();
sb.append("{");
Iterator<Map.Entry<ColumnIdentifier, Term.Raw>> iter = entries.entrySet().iterator();
while (iter.hasNext())
{
Map.Entry<ColumnIdentifier, Term.Raw> entry = iter.next();
sb.append(entry.getKey()).append(":").append(entry.getValue());
if (iter.hasNext())
sb.append(", ");
}
sb.append("}");
return sb.toString();
}
}
// Same purpose than Lists.DelayedValue, except we do handle bind marker in that case
public static class DelayedValue extends Term.NonTerminal
{
private final UserType type;
private final List<Term> values;
public DelayedValue(UserType type, List<Term> values)
{
this.type = type;
this.values = values;
}
public boolean usesFunction(String ksName, String functionName)
{
if (values != null)
for (Term value : values)
if (value != null && value.usesFunction(ksName, functionName))
return true;
return false;
}
public boolean containsBindMarker()
{
for (Term t : values)
if (t.containsBindMarker())
return true;
return false;
}
public void collectMarkerSpecification(VariableSpecifications boundNames)
{
for (int i = 0; i < type.size(); i++)
values.get(i).collectMarkerSpecification(boundNames);
}
private ByteBuffer[] bindInternal(QueryOptions options) throws InvalidRequestException
{
int version = options.getProtocolVersion();
ByteBuffer[] buffers = new ByteBuffer[values.size()];
for (int i = 0; i < type.size(); i++)
{
buffers[i] = values.get(i).bindAndGet(options);
// Inside UDT values, we must force the serialization of collections to v3 whatever protocol
// version is in use since we're going to store directly that serialized value.
if (version < Server.VERSION_3 && type.fieldType(i).isCollection() && buffers[i] != null)
buffers[i] = ((CollectionType)type.fieldType(i)).getSerializer().reserializeToV3(buffers[i]);
}
return buffers;
}
public Constants.Value bind(QueryOptions options) throws InvalidRequestException
{
return new Constants.Value(bindAndGet(options));
}
@Override
public ByteBuffer bindAndGet(QueryOptions options) throws InvalidRequestException
{
return UserType.buildValue(bindInternal(options));
}
}
}

191
cql3/user_types.hh Normal file
View File

@@ -0,0 +1,191 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Modified by Cloudius Systems
*
* Copyright 2015 Cloudius Systems
*/
#pragma once
#include "column_specification.hh"
#include "term.hh"
#include "column_identifier.hh"
#include "constants.hh"
#include "to_string.hh"
#include <boost/range/adaptor/transformed.hpp>
#include <boost/range/algorithm/count.hpp>
#include <boost/algorithm/cxx11/any_of.hpp>
namespace cql3 {
/**
* Static helper methods and classes for user types.
*/
class user_types {
user_types() = delete;
public:
static shared_ptr<column_specification> field_spec_of(shared_ptr<column_specification> column, size_t field) {
auto&& ut = static_pointer_cast<user_type_impl>(column->type);
auto&& name = ut->field_name(field);
auto&& sname = sstring(reinterpret_cast<const char*>(name.data()), name.size());
return make_shared<column_specification>(
column->ks_name,
column->cf_name,
make_shared<column_identifier>(column->name->to_string() + "." + sname, true),
ut->field_type(field));
}
class literal : public term::raw {
public:
using elements_map_type = std::unordered_map<column_identifier, shared_ptr<term::raw>>;
elements_map_type _entries;
literal(elements_map_type entries)
: _entries(std::move(entries)) {
}
virtual shared_ptr<term> prepare(database& db, const sstring& keyspace, shared_ptr<column_specification> receiver) override {
validate_assignable_to(db, keyspace, receiver);
auto&& ut = static_pointer_cast<user_type_impl>(receiver->type);
bool all_terminal = true;
std::vector<shared_ptr<term>> values;
values.reserve(_entries.size());
size_t found_values = 0;
for (size_t i = 0; i < ut->size(); ++i) {
auto&& field = column_identifier(to_bytes(ut->field_name(i)), utf8_type);
auto iraw = _entries.find(field);
shared_ptr<term::raw> raw;
if (iraw == _entries.end()) {
raw = cql3::constants::NULL_LITERAL;
} else {
raw = iraw->second;
++found_values;
}
auto&& value = raw->prepare(db, keyspace, field_spec_of(receiver, i));
if (dynamic_cast<non_terminal*>(value.get())) {
all_terminal = false;
}
values.push_back(std::move(value));
}
if (found_values != _entries.size()) {
// We had some field that are not part of the type
for (auto&& id_val : _entries) {
auto&& id = id_val.first;
if (!boost::range::count(ut->field_names(), id.bytes_)) {
throw exceptions::invalid_request_exception(sprint("Unknown field '%s' in value of user defined type %s", id, ut->get_name_as_string()));
}
}
}
delayed_value value(ut, values);
if (all_terminal) {
return value.bind(query_options::DEFAULT);
} else {
return make_shared(std::move(value));
}
}
private:
void validate_assignable_to(database& db, const sstring& keyspace, shared_ptr<column_specification> receiver) {
auto&& ut = dynamic_pointer_cast<user_type_impl>(receiver->type);
if (!ut) {
throw exceptions::invalid_request_exception(sprint("Invalid user type literal for %s of type %s", receiver->name, receiver->type->as_cql3_type()));
}
for (size_t i = 0; i < ut->size(); i++) {
column_identifier field(to_bytes(ut->field_name(i)), utf8_type);
if (_entries.count(field) == 0) {
continue;
}
shared_ptr<term::raw> value = _entries[field];
auto&& field_spec = field_spec_of(receiver, i);
if (!assignment_testable::is_assignable(value->test_assignment(db, keyspace, field_spec))) {
throw exceptions::invalid_request_exception(sprint("Invalid user type literal for %s: field %s is not of type %s", receiver->name, field, field_spec->type->as_cql3_type()));
}
}
}
public:
virtual assignment_testable::test_result test_assignment(database& db, const sstring& keyspace, shared_ptr<column_specification> receiver) override {
try {
validate_assignable_to(db, keyspace, receiver);
return assignment_testable::test_result::WEAKLY_ASSIGNABLE;
} catch (exceptions::invalid_request_exception& e) {
return assignment_testable::test_result::NOT_ASSIGNABLE;
}
}
virtual sstring assignment_testable_source_context() const override {
return to_string();
}
virtual sstring to_string() const override {
auto kv_to_str = [] (auto&& kv) { return sprint("%s:%s", kv.first, kv.second); };
return sprint("{%s}", ::join(", ", _entries | boost::adaptors::transformed(kv_to_str)));
}
};
// Same purpose than Lists.DelayedValue, except we do handle bind marker in that case
class delayed_value : public non_terminal {
user_type _type;
std::vector<shared_ptr<term>> _values;
public:
delayed_value(user_type type, std::vector<shared_ptr<term>> values)
: _type(std::move(type)), _values(std::move(values)) {
}
virtual bool uses_function(const sstring& ks_name, const sstring& function_name) const override {
return boost::algorithm::any_of(_values,
std::bind(&term::uses_function, std::placeholders::_1, std::cref(ks_name), std::cref(function_name)));
}
virtual bool contains_bind_marker() const override {
return boost::algorithm::any_of(_values, std::mem_fn(&term::contains_bind_marker));
}
virtual void collect_marker_specification(shared_ptr<variable_specifications> bound_names) {
for (auto&& v : _values) {
v->collect_marker_specification(bound_names);
}
}
private:
std::vector<bytes_opt> bind_internal(const query_options& options) {
auto sf = options.get_serialization_format();
std::vector<bytes_opt> buffers;
for (size_t i = 0; i < _type->size(); ++i) {
buffers.push_back(_values[i]->bind_and_get(options));
// Inside UDT values, we must force the serialization of collections to v3 whatever protocol
// version is in use since we're going to store directly that serialized value.
if (sf != serialization_format::use_32_bit() && _type->field_type(i)->is_collection() && buffers.back()) {
auto&& ctype = static_pointer_cast<collection_type_impl>(_type->field_type(i));
buffers.back() = ctype->reserialize(sf, serialization_format::use_32_bit(), bytes_view(*buffers.back()));
}
}
return buffers;
}
public:
virtual shared_ptr<terminal> bind(const query_options& options) override {
return ::make_shared<constants::value>(bind_and_get(options));
}
virtual bytes_opt bind_and_get(const query_options& options) override {
return user_type_impl::build_value(bind_internal(options));
}
};
};
}