/*
 * Copyright (C) 2014 ScyllaDB
 */

/*
 * This file is part of Scylla.
 *
 * Scylla is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Scylla is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Scylla.  If not, see <http://www.gnu.org/licenses/>.
 */

// NOTE(review): many declarations below read as if their template argument
// lists were lost in transit (e.g. `std::unordered_multimap>`, `shared_ptr f`,
// `boost::lexical_cast(...)`, `dynamic_cast(...)`, `std::vector>`). Those
// tokens are reproduced exactly as found rather than guessed; restore them
// from version control before building.

#include "functions.hh"
#include "function_call.hh"
#include "token_fct.hh"
#include "cql3/maps.hh"
#include "cql3/sets.hh"
#include "cql3/lists.hh"
#include "cql3/constants.hh"

namespace cql3 {
namespace functions {

// Per-thread table of declared functions, keyed by function name and
// populated once from init().
thread_local std::unordered_multimap> functions::_declared = init();

// Builds the table of built-in functions.
//
// Every function is inserted under the key returned by its own name(), so
// several overloads may share one key.
//
// Returns the populated multimap.
std::unordered_multimap>
functions::init() {
    std::unordered_multimap> ret;
    auto declare = [&ret] (shared_ptr f) { ret.emplace(f->name(), f); };
    declare(aggregate_fcts::make_count_rows_function());
    declare(time_uuid_fcts::make_now_fct());
    declare(time_uuid_fcts::make_min_timeuuid_fct());
    declare(time_uuid_fcts::make_max_timeuuid_fct());
    declare(time_uuid_fcts::make_date_of_fct());
    declare(time_uuid_fcts::make_unix_timestamp_of_fcf());
    declare(make_uuid_fct());

    for (auto&& type : cql3_type::values()) {
        // Note: because text and varchar end up being synonymous, the automatic
        // to-blob function does not work for varchar, so it is special-cased
        // below. Blob is skipped as well: converting blob to blob is pointless.
        if (type == cql3_type::varchar || type == cql3_type::blob) {
            continue;
        }
        // Counters are not supported yet.
        if (type->is_counter()) {
            warn(unimplemented::cause::COUNTERS);
            continue;
        }

        declare(make_to_blob_function(type->get_type()));
        declare(make_from_blob_function(type->get_type()));
    }

    // NOTE(review): the two groups below look identical here; presumably they
    // were distinct template instantiations. Kept as found.
    declare(aggregate_fcts::make_count_function());
    declare(aggregate_fcts::make_max_function());
    declare(aggregate_fcts::make_min_function());

    declare(aggregate_fcts::make_count_function());
    declare(aggregate_fcts::make_max_function());
    declare(aggregate_fcts::make_min_function());

    //FIXME:
    //declare(aggregate_fcts::make_count_function());
    //declare(aggregate_fcts::make_max_function());
    //declare(aggregate_fcts::make_min_function());

    // FIXME: more count/min/max

    declare(make_varchar_as_blob_fct());
    declare(make_blob_as_varchar_fct());

    // NOTE(review): same remark as above for the repeated sum/avg calls.
    declare(aggregate_fcts::make_sum_function());
    declare(aggregate_fcts::make_sum_function());
    declare(aggregate_fcts::make_avg_function());
    declare(aggregate_fcts::make_avg_function());
#if 0
    declare(AggregateFcts.sumFunctionForFloat);
    declare(AggregateFcts.sumFunctionForDouble);
    declare(AggregateFcts.sumFunctionForDecimal);
    declare(AggregateFcts.sumFunctionForVarint);
    declare(AggregateFcts.avgFunctionForFloat);
    declare(AggregateFcts.avgFunctionForDouble);
    declare(AggregateFcts.avgFunctionForVarint);
    declare(AggregateFcts.avgFunctionForDecimal);
#endif

    // also needed for smp:
#if 0
    MigrationManager.instance.register(new FunctionsMigrationListener());
#endif
    return ret;
}

// Builds the specification describing argument `i` of `fun`.
//
// The specification is named "arg<i>(<lower-cased function name>)" and carries
// the type `fun.arg_types()[i]`.
//
// receiver_ks, receiver_cf: forwarded unchanged into the specification.
// fun: function whose argument is described.
// i: zero-based argument index; not range-checked here.
shared_ptr
functions::make_arg_spec(const sstring& receiver_ks, const sstring& receiver_cf,
        const function& fun, size_t i) {
    auto&& name = boost::lexical_cast(fun.name());
    std::transform(name.begin(), name.end(), name.begin(), ::tolower);
    return ::make_shared(receiver_ks,
                   receiver_cf,
                   ::make_shared(sprint("arg%d(%s)", i, name), true),
                   fun.arg_types()[i]);
}

// Returns how many entries are declared under `name`.
int
functions::get_overload_count(const function_name& name) {
    return _declared.count(name);
}

// Resolves the function to call for `name` given the provided arguments.
//
// The "token" function is handled first and is built from the receiver's
// schema. Otherwise the declared candidates are collected: for an unqualified
// name both the native function name and the name qualified with `keyspace`
// are looked up; for a qualified name only that exact name.
//
// Returns an empty pointer when no candidate is declared.
// Throws exceptions::invalid_request_exception when several candidates exist
// and either none or more than one is compatible without an exact match. With
// a single candidate, validate_types() is called on it before returning it.
shared_ptr
functions::get(database& db,
        const sstring& keyspace,
        const function_name& name,
        const std::vector>& provided_args,
        const sstring& receiver_ks,
        const sstring& receiver_cf) {

    static const function_name TOKEN_FUNCTION_NAME = function_name::native_function("token");

    if (name.has_keyspace()
                ? name == TOKEN_FUNCTION_NAME
                : name.name == TOKEN_FUNCTION_NAME.name) {
        return ::make_shared(db.find_schema(receiver_ks, receiver_cf));
    }

    std::vector> candidates;
    auto&& add_declared = [&] (function_name fn) {
        auto&& fns = _declared.equal_range(fn);
        for (auto i = fns.first; i != fns.second; ++i) {
            candidates.push_back(i->second);
        }
    };
    if (!name.has_keyspace()) {
        // add 'SYSTEM' (native) candidates
        add_declared(name.as_native_function());
        add_declared(function_name(keyspace, name.name));
    } else {
        // function name is fully qualified (keyspace + name)
        add_declared(name);
    }

    if (candidates.empty()) {
        return {};
    }

    // Fast path if there is only one choice; validate_types() gives a more
    // precise error than the generic overload-resolution messages below.
    if (candidates.size() == 1) {
        auto fun = std::move(candidates[0]);
        validate_types(db, keyspace, fun, provided_args, receiver_ks, receiver_cf);
        return fun;
    }

    std::vector> compatibles;
    for (auto&& to_test : candidates) {
        auto r = match_arguments(db, keyspace, to_test, provided_args, receiver_ks, receiver_cf);
        switch (r) {
            case assignment_testable::test_result::EXACT_MATCH:
                // We always favor exact matches
                return to_test;
            case assignment_testable::test_result::WEAKLY_ASSIGNABLE:
                compatibles.push_back(std::move(to_test));
                break;
            default:
                break;
        }
    }

    // `candidates` is only formatted when nothing was moved out of it, i.e.
    // when `compatibles` is empty.
    if (compatibles.empty()) {
        throw exceptions::invalid_request_exception(
                sprint("Invalid call to function %s, none of its type signatures match (known type signatures: %s)",
                                                        name, join(", ", candidates)));
    }

    if (compatibles.size() > 1) {
        throw exceptions::invalid_request_exception(
                sprint("Ambiguous call to function %s (can be matched by following signatures: %s): use type casts to disambiguate",
                    name, join(", ", compatibles)));
    }

    return std::move(compatibles[0]);
}

// Returns every function declared under `name`, in table iteration order.
std::vector>
functions::find(const function_name& name) {
    auto range = _declared.equal_range(name);
    std::vector> ret;
    for (auto i = range.first; i != range.second; ++i) {
        ret.push_back(i->second);
    }
    return ret;
}

// Returns the function declared under `name` whose argument types satisfy
// type_equals() against `arg_types`, or an empty pointer if there is none.
//
// `name` must be fully qualified (asserted).
// NOTE(review): type_equals() currently aborts, so this overload terminates
// the process as soon as at least one function is declared under `name`.
shared_ptr
functions::find(const function_name& name, const std::vector& arg_types) {
    assert(name.has_keyspace()); // : "function name not fully qualified";
    for (auto&& f : find(name)) {
        if (type_equals(f->arg_types(), arg_types)) {
            return f;
        }
    }
    return {};
}

// Checks that `provided_args` can be passed to `fun`.
//
// This method and match_arguments() are somewhat duplicate, but this one gives
// more precise errors in the common case where a function has no overload,
// which is probably worth the minor code duplication.
//
// Throws exceptions::invalid_request_exception when the argument count differs
// from fun->arg_types().size() or when an argument is not assignable to the
// corresponding argument specification.
void
functions::validate_types(database& db,
                          const sstring& keyspace,
                          shared_ptr fun,
                          const std::vector>& provided_args,
                          const sstring& receiver_ks,
                          const sstring& receiver_cf) {
    if (provided_args.size() != fun->arg_types().size()) {
        throw exceptions::invalid_request_exception(
                sprint("Invalid number of arguments in call to function %s: %d required but %d provided",
                        fun->name(), fun->arg_types().size(), provided_args.size()));
    }

    for (size_t i = 0; i < provided_args.size(); ++i) {
        auto&& provided = provided_args[i];

        // If the concrete argument is a bind variable, it can have any type.
        // We'll validate the actually provided value at execution time.
        if (!provided) {
            continue;
        }

        auto&& expected = make_arg_spec(receiver_ks, receiver_cf, *fun, i);
        if (!is_assignable(provided->test_assignment(db, keyspace, expected))) {
            throw exceptions::invalid_request_exception(
                    sprint("Type error: %s cannot be passed as argument %d of function %s of type %s",
                            provided, i, fun->name(), expected->type->as_cql3_type()));
        }
    }
}

// Rates how well `provided_args` match the signature of `fun`.
//
// Returns NOT_ASSIGNABLE when the argument count differs or any argument is
// not assignable; WEAKLY_ASSIGNABLE when every argument is assignable but at
// least one is null or only weakly assignable; EXACT_MATCH otherwise.
assignment_testable::test_result
functions::match_arguments(database& db, const sstring& keyspace,
        shared_ptr fun,
        const std::vector>& provided_args,
        const sstring& receiver_ks,
        const sstring& receiver_cf) {
    if (provided_args.size() != fun->arg_types().size()) {
        return assignment_testable::test_result::NOT_ASSIGNABLE;
    }

    // It's an exact match if all are exact match, but is not assignable as soon as any is non assignable.
    auto res = assignment_testable::test_result::EXACT_MATCH;
    for (size_t i = 0; i < provided_args.size(); ++i) {
        auto&& provided = provided_args[i];
        if (!provided) {
            res = assignment_testable::test_result::WEAKLY_ASSIGNABLE;
            continue;
        }
        auto&& expected = make_arg_spec(receiver_ks, receiver_cf, *fun, i);
        auto arg_res = provided->test_assignment(db, keyspace, expected);
        if (arg_res == assignment_testable::test_result::NOT_ASSIGNABLE) {
            return assignment_testable::test_result::NOT_ASSIGNABLE;
        }
        if (arg_res == assignment_testable::test_result::WEAKLY_ASSIGNABLE) {
            res = assignment_testable::test_result::WEAKLY_ASSIGNABLE;
        }
    }
    return res;
}

// Intended to compare two argument type lists; not implemented yet.
//
// The disabled block is the reference logic still to be ported. Until then
// any call aborts the process.
bool
functions::type_equals(const std::vector& t1, const std::vector& t2) {
#if 0
    if (t1.size() != t2.size())
        return false;
    for (int i = 0; i < t1.size(); i ++)
        if (!typeEquals(t1.get(i), t2.get(i)))
            return false;
    return true;
#endif
    abort();
}

// Forwards to the wrapped function's uses_function().
bool
function_call::uses_function(const sstring& ks_name, const sstring& function_name) const {
    return _fun->uses_function(ks_name, function_name);
}

// Asks every argument term to record its markers in `bound_names`.
void
function_call::collect_marker_specification(shared_ptr bound_names) {
    for (auto&& t : _terms) {
        t->collect_marker_specification(bound_names);
    }
}

// Evaluates the call with `options` and wraps the result with make_terminal().
shared_ptr
function_call::bind(const query_options& options) {
    return make_terminal(_fun, to_bytes_opt(bind_and_get(options)), options.get_cql_serialization_format());
}

// Binds every argument term, runs the function and returns its result.
//
// The returned view refers to storage obtained from
// options.make_temporary().
//
// Throws exceptions::invalid_request_exception if any argument binds to null.
bytes_view_opt
function_call::bind_and_get(const query_options& options) {
    std::vector buffers;
    buffers.reserve(_terms.size());
    for (auto&& t : _terms) {
        // For now, nulls are not allowed as arguments: no existing function
        // needs them and it simplifies things.
        auto val = t->bind_and_get(options);
        if (!val) {
            throw exceptions::invalid_request_exception(sprint("Invalid null value for argument to %s", *_fun));
        }
        // to_bytes_opt() already yields a temporary, so no std::move is needed.
        buffers.push_back(to_bytes_opt(val));
    }
    auto result = execute_internal(options.get_cql_serialization_format(), *_fun, std::move(buffers));
    return options.make_temporary(result);
}

// Runs `fun` on `params` and validates a non-null result against the
// function's declared return type.
//
// Throws runtime_exception when that validation raises marshal_exception.
// Anything thrown by fun.execute() itself propagates unchanged.
bytes_opt
function_call::execute_internal(cql_serialization_format sf, scalar_function& fun, std::vector params) {
    bytes_opt result = fun.execute(sf, params);
    try {
        // Check the function did not lie about its declared return type.
        if (result) {
            fun.return_type()->validate(*result);
        }
        return result;
    } catch (const marshal_exception&) {
        // Caught by const reference to avoid copying or slicing the exception.
        throw runtime_exception(sprint("Return of function %s (%s) is not a valid value for its declared return type %s",
                                       fun, to_hex(result),
                                       *fun.return_type()->as_cql3_type()
                                       ));
    }
}

// Returns true if any argument term contains a bind marker.
bool
function_call::contains_bind_marker() const {
    for (auto&& t : _terms) {
        if (t->contains_bind_marker()) {
            return true;
        }
    }
    return false;
}

// Wraps a serialized function result in a terminal suited to the function's
// return type.
//
// When the dynamic_pointer_cast of the return type fails, the result is
// wrapped directly. Otherwise the kind of the type selects a list, set or map
// value deserialized from `result`; an empty view is used when `result` is
// null. Kinds are compared by address. Aborts on any other kind.
shared_ptr
function_call::make_terminal(shared_ptr fun, bytes_opt result, cql_serialization_format sf) {
    if (!dynamic_pointer_cast(fun->return_type())) {
        return ::make_shared(std::move(result));
    }

    auto ctype = static_pointer_cast(fun->return_type());
    bytes_view res;
    if (result) {
        res = *result;
    }
    if (&ctype->_kind == &collection_type_impl::kind::list) {
        return make_shared(lists::value::from_serialized(std::move(res), static_pointer_cast(ctype), sf));
    } else if (&ctype->_kind == &collection_type_impl::kind::set) {
        return make_shared(sets::value::from_serialized(std::move(res), static_pointer_cast(ctype), sf));
    } else if (&ctype->_kind == &collection_type_impl::kind::map) {
        return make_shared(maps::value::from_serialized(std::move(res), static_pointer_cast(ctype), sf));
    }
    abort();
}

// Resolves the called function and prepares its arguments against `receiver`.
//
// When every prepared argument is terminal and the function is pure, the call
// is evaluated immediately and a terminal is returned; otherwise a deferred
// call object holding the prepared parameters is returned.
//
// Throws exceptions::invalid_request_exception when the function is unknown,
// is an aggregate, has a return type that is not value-compatible with the
// receiver, or is called with the wrong number of arguments.
::shared_ptr
function_call::raw::prepare(database& db, const sstring& keyspace, ::shared_ptr receiver) {
    std::vector> args;
    args.reserve(_terms.size());
    std::transform(_terms.begin(), _terms.end(), std::back_inserter(args),
            [] (auto&& x) -> shared_ptr {
        return x;
    });
    auto&& fun = functions::functions::get(db, keyspace, _name, args, receiver->ks_name, receiver->cf_name);
    if (!fun) {
        throw exceptions::invalid_request_exception(sprint("Unknown function %s called", _name));
    }
    if (fun->is_aggregate()) {
        throw exceptions::invalid_request_exception("Aggregation function are not supported in the where clause");
    }

    // Can't use static_pointer_cast<> because function is a virtual base class of scalar_function
    auto&& scalar_fun = dynamic_pointer_cast(fun);

    // functions::get() complains if no overload of the function type-checks
    // with the provided arguments. The return type still has to be validated
    // against the receiver here.
    if (!receiver->type->is_value_compatible_with(*scalar_fun->return_type())) {
        throw exceptions::invalid_request_exception(sprint("Type error: cannot assign result of function %s (type %s) to %s (type %s)",
                                                    fun->name(), fun->return_type()->as_cql3_type(),
                                                    receiver->name, receiver->type->as_cql3_type()));
    }

    if (scalar_fun->arg_types().size() != _terms.size()) {
        throw exceptions::invalid_request_exception(sprint("Incorrect number of arguments specified for function %s (expected %d, found %d)",
                                                    fun->name(), fun->arg_types().size(), _terms.size()));
    }

    std::vector> parameters;
    parameters.reserve(_terms.size());
    bool all_terminal = true;
    for (size_t i = 0; i < _terms.size(); ++i) {
        auto&& t = _terms[i]->prepare(db, keyspace, functions::make_arg_spec(receiver->ks_name, receiver->cf_name, *scalar_fun, i));
        // NOTE(review): the cast target is not visible here; a successful cast
        // marks the parameter as non-terminal.
        if (dynamic_cast(t.get())) {
            all_terminal = false;
        }
        parameters.push_back(t);
    }

    // If all parameters are terminal and the function is pure, we can
    // evaluate it now, otherwise we'd have to wait execution time
    if (all_terminal && scalar_fun->is_pure()) {
        return make_terminal(scalar_fun, execute(*scalar_fun, parameters), query_options::DEFAULT.get_cql_serialization_format());
    } else {
        return ::make_shared(scalar_fun, parameters);
    }
}

// Evaluates `fun` at prepare time on `parameters`.
//
// Each parameter is asserted to pass a dynamic_cast check and is then read
// with query_options::DEFAULT. The call itself uses the internal
// serialization format.
bytes_opt
function_call::raw::execute(scalar_function& fun, std::vector> parameters) {
    std::vector buffers;
    buffers.reserve(parameters.size());
    for (auto&& t : parameters) {
        assert(dynamic_cast(t.get()));
        auto&& param = static_cast(t.get())->get(query_options::DEFAULT);
        buffers.push_back(std::move(param));
    }

    // `buffers` is not used afterwards, so hand it over instead of copying it
    // into the by-value parameter.
    return execute_internal(cql_serialization_format::internal(), fun, std::move(buffers));
}

// Rates whether the result of this call can be assigned to `receiver`.
//
// Returns EXACT_MATCH when the resolved function's return type equals the
// receiver type. Returns WEAKLY_ASSIGNABLE when the types are value-compatible,
// when no function was found, or when resolution threw
// exceptions::invalid_request_exception. Returns NOT_ASSIGNABLE otherwise.
assignment_testable::test_result
function_call::raw::test_assignment(database& db, const sstring& keyspace, shared_ptr receiver) {
    // Note: functions::get() returns null if the function doesn't exist, or
    // throws if no overload matching the arguments can be found. Either can
    // happen when an undefined or wrong function is used as the argument of
    // another, existing, function. In that case a positive result is reported
    // here because a proper exception with a more helpful message is thrown
    // later than if this returned a negative result.
    try {
        auto&& fun = functions::get(db, keyspace, _name, _terms, receiver->ks_name, receiver->cf_name);
        if (fun && receiver->type->equals(fun->return_type())) {
            return assignment_testable::test_result::EXACT_MATCH;
        } else if (!fun || receiver->type->is_value_compatible_with(*fun->return_type())) {
            return assignment_testable::test_result::WEAKLY_ASSIGNABLE;
        } else {
            return assignment_testable::test_result::NOT_ASSIGNABLE;
        }
    } catch (const exceptions::invalid_request_exception&) {
        return assignment_testable::test_result::WEAKLY_ASSIGNABLE;
    }
}

// Renders the call as "<name>(<term>, <term>, ...)".
sstring
function_call::raw::to_string() const {
    return sprint("%s(%s)", _name, join(", ", _terms));
}

}
}