Files
scylladb/lang/lua.cc
Yaniv Michael Kaul ead9961783 cql: vector: fix vector dimension type
Switch vector dimension handling to fixed-width `uint32_t` type,
update parsing/validation, and add boundary tests.

The dimension is parsed as `unsigned long` at first which is guaranteed
to be **at least** 32-bit long, which is safe to downcast to `uint32_t`.

Move `MAX_VECTOR_DIMENSION` from `cql3_type::raw_vector` to `cql3_type`
to ensure public visibility for checks outside the class.

Add tests to verify the type boundaries.

Fixes: https://scylladb.atlassian.net/browse/SCYLLADB-223

Signed-off-by: Yaniv Kaul <yaniv.kaul@scylladb.com>
Co-authored-by: Dawid Pawlik <dawid.pawlik@scylladb.com>

Closes scylladb/scylladb#28762
2026-02-26 14:46:53 +02:00

1154 lines
41 KiB
C++

/*
* Copyright (C) 2019-present ScyllaDB
*/
/*
* SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
*/
#include <boost/date_time/gregorian/greg_date.hpp>
#include <boost/date_time/posix_time/posix_time.hpp>
#include <random>
#include "lua.hh"
#include "lang/lua_scylla_types.hh"
#include "exceptions/exceptions.hh"
#include "types/concrete_types.hh"
#include "utils/assert.hh"
#include "utils/utf8.hh"
#include "utils/ascii.hh"
#include "utils/date.h"
#include <seastar/core/align.hh>
#include <lua.hpp>
#include "seastarx.hh"
// Lua 5.4 added an extra parameter to lua_resume
#if LUA_VERSION_NUM >= 504
# define LUA_504_PLUS(x...) x
#else
# define LUA_504_PLUS(x...)
#endif
// Lua 5.5 added a seed parameter to lua_newstate
#if LUA_VERSION_NUM >= 505
# define LUA_505_PLUS(x...) x
#else
# define LUA_505_PLUS(x...)
#endif
using namespace seastar;
using namespace lua;
static logging::logger lua_logger("lua");
namespace {
struct alloc_state {
size_t allocated = 0;
size_t max;
size_t max_contiguous;
alloc_state(size_t max, size_t max_contiguous)
: max(max)
, max_contiguous(max_contiguous) {
// The max and max_contiguous limits are responsible for avoiding overflows.
SCYLLA_ASSERT(max + max_contiguous >= max);
}
};
struct lua_closer {
void operator()(lua_State* l) {
lua_close(l);
}
};
static const char scylla_decimal_metatable_name[] = "Scylla.decimal";
class lua_slice_state {
std::unique_ptr<alloc_state> a_state;
std::unique_ptr<lua_State, lua_closer> _l;
public:
lua_slice_state(std::unique_ptr<alloc_state> a_state, std::unique_ptr<lua_State, lua_closer> l)
: a_state(std::move(a_state))
, _l(std::move(l)) {}
operator lua_State*() { return _l.get(); }
};
}
static void* lua_alloc(void* ud, void* ptr, size_t osize, size_t nsize) {
auto* s = reinterpret_cast<alloc_state*>(ud);
// avoid realloc(nullptr, 0), which would allocate.
if (nsize == 0 && ptr == nullptr) {
return nullptr;
}
if (nsize > s->max_contiguous) {
return nullptr;
}
size_t next = s->allocated + nsize;
// The max and max_contiguous limits should be small enough to avoid overflows.
SCYLLA_ASSERT(next >= s->allocated);
if (ptr) {
next -= osize;
}
if (next > s->max) {
lua_logger.info("allocation failed. already allocated = {}, next total = {}, max = {}", s->allocated, next, s->max);
return nullptr;
}
// FIXME: Given that we have osize, we can probably do better when
// SEASTAR_DEFAULT_ALLOCATOR is false
void* ret = realloc(ptr, nsize);
if (nsize == 0 || ret != nullptr) {
s->allocated = next;
}
return ret;
}
static const luaL_Reg loadedlibs[] = {
{"_G", luaopen_base},
{LUA_STRLIBNAME, luaopen_string},
{LUA_COLIBNAME, luaopen_coroutine},
{LUA_TABLIBNAME, luaopen_table},
{NULL, NULL},
};
static void debug_hook(lua_State* l, lua_Debug* ar) {
if (!need_preempt()) {
return;
}
// The lua manual says that only count and line events can yield. Of those we only use count.
if (ar->event != LUA_HOOKCOUNT) {
// Set the hook to stop at the very next lua instruction, where we will be able to yield.
lua_sethook(l, debug_hook, LUA_MASKCOUNT, 1);
return;
}
if (lua_yield(l, 0)) {
SCYLLA_ASSERT(0 && "lua_yield failed");
}
}
static lua_slice_state new_lua(const lua::runtime_config& cfg) {
auto a_state = std::make_unique<alloc_state>(cfg.max_bytes, cfg.max_contiguous);
#if LUA_VERSION_NUM >= 505
static thread_local std::default_random_engine rng{std::random_device{}()};
auto seed = rng();
#endif
std::unique_ptr<lua_State, lua_closer> l{lua_newstate(lua_alloc, a_state.get() LUA_505_PLUS(, seed))};
if (!l) {
throw std::runtime_error("could not create lua state");
}
return lua_slice_state{std::move(a_state), std::move(l)};
}
static int string_writer(lua_State*, const void* p, size_t size, void* data) {
luaL_addlstring(reinterpret_cast<luaL_Buffer*>(data), reinterpret_cast<const char*>(p), size);
return 0;
}
static int compile_l(lua_State* l) {
const auto& script = *reinterpret_cast<sstring*>(lua_touserdata(l, 1));
luaL_Buffer buf;
luaL_buffinit(l, &buf);
if (luaL_loadbufferx(l, script.c_str(), script.size(), "<internal>", "t")) {
lua_error(l);
}
if (lua_dump(l, string_writer, &buf, true)) {
luaL_error(l, "lua_dump failed");
}
luaL_pushresult(&buf);
return 1;
}
sstring lua::compile(const runtime_config& cfg, const std::vector<sstring>& arg_names, sstring script) {
if (!arg_names.empty()) {
// In Lua, all chunks are compiled to vararg functions. To use
// the UDF argument names, we start the chunk with "local
// arg1,arg2,etc = ...", which captures the arguments when the
// chunk is called.
std::ostringstream os;
os << "local ";
for (int i = 0, n = arg_names.size(); i < n; ++i) {
if (i != 0) {
os << ",";
}
os << arg_names[i];
}
os << " = ...;\n" << script;
script = os.str();
}
lua_slice_state l = new_lua(cfg);
// Run the load from lua_pcall so we don't have to handle longjmp.
lua_pushcfunction(l, compile_l);
lua_pushlightuserdata(l, &script);
if (lua_pcall(l, 1, 1, 0)) {
throw exceptions::invalid_request_exception(std::string("could not compile: ") + lua_tostring(l, -1));
}
size_t len;
const char* p = lua_tolstring(l, -1, &len);
return sstring(p, len);
}
static big_decimal* get_decimal(lua_State* l, int arg) {
constexpr size_t alignment = alignof(big_decimal);
char* p = reinterpret_cast<char*>(luaL_testudata(l, arg, scylla_decimal_metatable_name));
if (p) {
return reinterpret_cast<big_decimal*>(align_up(p, alignment));
}
return nullptr;
}
static void push_big_decimal(lua_State* l, const big_decimal& v) {
auto* p = aligned_user_data<big_decimal>(l);
new (p) big_decimal(v);
luaL_setmetatable(l, scylla_decimal_metatable_name);
}
static void push_cpp_int(lua_State* l, const utils::multiprecision_int& v) {
push_big_decimal(l, big_decimal(0, v));
}
struct lua_table {
};
template <typename Func>
using lua_visit_ret_type = std::invoke_result_t<Func, const double&>;
template <typename Func>
concept CanHandleRawLuaTypes = requires(Func f) {
{ f(*static_cast<const long long*>(nullptr)) } -> std::same_as<lua_visit_ret_type<Func>>;
{ f(*static_cast<const double*>(nullptr)) } -> std::same_as<lua_visit_ret_type<Func>>;
{ f(*static_cast<const big_decimal*>(nullptr)) } -> std::same_as<lua_visit_ret_type<Func>>;
{ f(*static_cast<const std::string_view*>(nullptr)) } -> std::same_as<lua_visit_ret_type<Func>>;
{ f(*static_cast<const lua_table*>(nullptr)) } -> std::same_as<lua_visit_ret_type<Func>>;
};
template <typename Func>
requires CanHandleRawLuaTypes<Func>
static auto visit_lua_raw_value(lua_State* l, int index, Func&& f) {
switch (lua_type(l, index)) {
case LUA_TNONE:
SCYLLA_ASSERT(0 && "Invalid index");
case LUA_TNUMBER:
if (lua_isinteger(l, index)) {
return f(lua_tointeger(l, index));
}
return f(lua_tonumber(l, index));
case LUA_TSTRING: {
size_t len;
const char* s = lua_tolstring(l, index, &len);
return f(std::string_view{s, len});
}
case LUA_TTABLE:
return f(lua_table{});
case LUA_TBOOLEAN:
case LUA_TFUNCTION:
case LUA_TNIL:
throw exceptions::invalid_request_exception("unexpected value");
case LUA_TUSERDATA:
return f(*get_decimal(l, index));
case LUA_TTHREAD:
case LUA_TLIGHTUSERDATA:
SCYLLA_ASSERT(0 && "We never make thread or light user data visible to scripts");
}
SCYLLA_ASSERT(0 && "invalid lua type");
}
template <typename Func>
static auto visit_decimal(const big_decimal &v, Func&& f) {
boost::multiprecision::cpp_rational r = v.as_rational();
const boost::multiprecision::cpp_int& dividend = numerator(r);
const boost::multiprecision::cpp_int& divisor = denominator(r);
if (dividend % divisor == 0) {
return f(utils::multiprecision_int(dividend/divisor));
}
return f(r.convert_to<double>());
}
template <typename Func>
concept CanHandleLuaTypes = requires(Func f) {
{ f(*static_cast<const double*>(nullptr)) } -> std::same_as<lua_visit_ret_type<Func>>;
{ f(*static_cast<const utils::multiprecision_int*>(nullptr)) } -> std::same_as<lua_visit_ret_type<Func>>;
{ f(*static_cast<const big_decimal*>(nullptr)) } -> std::same_as<lua_visit_ret_type<Func>>;
{ f(*static_cast<const std::string_view*>(nullptr)) } -> std::same_as<lua_visit_ret_type<Func>>;
{ f(*static_cast<const lua_table*>(nullptr)) } -> std::same_as<lua_visit_ret_type<Func>>;
};
template <typename Func>
requires CanHandleLuaTypes<Func>
static auto visit_lua_value(lua_State* l, int index, Func&& f) {
struct visitor {
lua_State* l;
int index;
Func& f;
auto operator()(const long long& v) { return f(utils::multiprecision_int(v)); }
auto operator()(const utils::multiprecision_int& v) { return f(v); }
auto operator()(const double& v) {
auto min = double(std::numeric_limits<long long>::min());
auto max = double(std::numeric_limits<long long>::max());
if (min <= v && v <= max && std::trunc(v) == v) {
return (*this)((long long)v);
}
// FIXME: We could use frexp to produce a decimal instead of a double
return f(v);
}
auto operator()(const std::string_view& v) {
big_decimal v2;
try {
v2 = big_decimal(v);
} catch (marshal_exception&) {
// The string is not a valid big_decimal. Let Lua try to convert it to a double.
int isnum;
double d = lua_tonumberx(l, index, &isnum);
if (isnum) {
return (*this)(d);
}
return f(v);
}
return (*this)(v2);
}
auto operator()(const big_decimal& v) {
struct visitor {
Func& f;
const big_decimal &d;
auto operator()(const double&) { return f(d); }
auto operator()(const utils::multiprecision_int& v) { return f(v); }
};
return visit_decimal(v, visitor{f, v});
}
auto operator()(const lua_table& v) {
return f(v);
}
};
return visit_lua_raw_value(l, index, visitor{l, index, f});
}
template <typename Func>
static auto visit_lua_number(lua_State* l, int index, Func&& f) {
return visit_lua_value(l, index, make_visitor(
[] (const std::string_view& v) -> std::invoke_result_t<Func, double> {
throw exceptions::invalid_request_exception("value is not a number");
},
[] (const lua_table&) -> std::invoke_result_t<Func, double> {
throw exceptions::invalid_request_exception("value is not a number");
},
std::forward<Func>(f)
));
}
template <typename Func> static auto visit_lua_decimal(lua_State* l, int index, Func&& f) {
return visit_lua_number(l, index, make_visitor(
[&f](const utils::multiprecision_int& v) { return f(big_decimal(0, v)); },
[&f](const auto& v) { return f(v); }
));
}
static int decimal_gc(lua_State *l) {
std::destroy_at(get_decimal(l, 1));
return 0;
}
static double decimal_to_double(const big_decimal &d) {
return visit_decimal(d, [] (auto&& v) { return double(v); });
}
static const big_decimal& get_decimal_in_binary_op(lua_State* l) {
auto* a = get_decimal(l, 1);
if (a == nullptr) {
lua_insert(l, 1);
a = get_decimal(l, 1);
SCYLLA_ASSERT(a);
}
return *a;
}
template<typename Func>
static void visit_decimal_bin_op(lua_State* l, Func&& F) {
const auto& a = get_decimal_in_binary_op(l);
struct bin_op_visitor {
const big_decimal& a;
Func& F;
lua_State* l;
void operator()(const double& b) {
lua_pushnumber(l, F(decimal_to_double(a), b));
}
void operator()(const big_decimal& b) {
push_big_decimal(l, F(a, b));
}
};
visit_lua_decimal(l, -1, bin_op_visitor{a, F, l});
}
static int decimal_add(lua_State* l) {
visit_decimal_bin_op(l, [](auto&& a, auto&& b) { return a + b; });
return 1;
}
static int decimal_sub(lua_State* l) {
visit_decimal_bin_op(l, [](auto&& a, auto&& b) { return a - b; });
return 1;
}
static const struct luaL_Reg decimal_methods[] {
{"__gc", decimal_gc},
{"__add", decimal_add},
{"__sub", decimal_sub},
{nullptr, nullptr}
};
static int load_script_l(lua_State* l) {
const auto& bitcode = *reinterpret_cast<lua::bitcode_view*>(lua_touserdata(l, 1));
const auto& binary = bitcode.bitcode;
for (const luaL_Reg* lib = loadedlibs; lib->func; lib++) {
luaL_requiref(l, lib->name, lib->func, 1);
lua_pop(l, 1);
}
lua::register_metatables(l);
if (luaL_loadbufferx(l, binary.data(), binary.size(), "<internal>", "b")) {
lua_error(l);
}
return 1;
}
static lua_slice_state load_script(const lua::runtime_config& cfg, lua::bitcode_view binary) {
lua_slice_state l = new_lua(cfg);
// Run the initialization from lua_pcall so we don't have to
// handle longjmp. We know that a new state has a few reserved
// stack slots and the following push calls don't allocate.
lua_pushcfunction(l, load_script_l);
lua_pushlightuserdata(l, &binary);
if (lua_pcall(l, 1, 1, 0)) {
throw std::runtime_error(std::string("could not initiate: ") + lua_tostring(l, -1));
}
return l;
}
using millisecond = std::chrono::duration<double, std::milli>;
static auto now() { return std::chrono::system_clock::now(); }
static utils::multiprecision_int get_varint(lua_State* l, int index) {
return visit_lua_number(l, index, make_visitor(
[](const utils::multiprecision_int& v) { return v; },
[](const auto& v) -> utils::multiprecision_int{
throw exceptions::invalid_request_exception("value is not an integer");
}
));
}
static sstring get_string(lua_State *l, int index) {
return visit_lua_value(l, index, make_visitor(
[] (const lua_table&) -> sstring {
throw exceptions::invalid_request_exception("unexpected value");
},
[] (const utils::multiprecision_int& p) {
return sstring(p.str());
},
[] (const auto& v) {
return seastar::format("{}", v);
}));
}
static data_value convert_from_lua(lua_State* l, const data_type& type);
namespace {
struct lua_date_table {
// The lua date table is documented at https://www.lua.org/pil/22.1.html
// date::year uses a int64_t, but there is no reason to try to
// support 64 bit years. In practice the limitations are
// * year_month_day::to_days hits a signed integer overflow for
// large years.
// * boost::gregorian only supports the years [1400,9999]
int32_t year;
// Both date::month and date::day use unsigned char.
unsigned char month;
unsigned char day;
std::optional<int32_t> hour;
std::optional<int32_t> minute;
std::optional<int32_t> second;
};
static lua_date_table get_lua_date_table(lua_State* l, int index) {
std::optional<int32_t> year;
std::optional<unsigned char> month;
std::optional<unsigned char> day;
std::optional<int32_t> hour;
std::optional<int32_t> minute;
std::optional<int32_t> second;
lua_pushnil(l);
while (lua_next(l, index - 1) != 0) {
auto k = get_string(l, index - 1);
auto v = get_varint(l, index);
lua_pop(l, 1);
if (k == "month") {
month = (unsigned char)v;
if (*month != v) {
throw exceptions::invalid_request_exception(seastar::format("month is too large: '{}'", v.str()));
}
} else if (k == "day") {
day = (unsigned char)v;
if (*day != v) {
throw exceptions::invalid_request_exception(seastar::format("day is too large: '{}'", v.str()));
}
} else {
int32_t vint(v);
if (vint != v) {
throw exceptions::invalid_request_exception(seastar::format("{} is too large: '{}'", k, v.str()));
}
if (k == "year") {
year = vint;
} else if (k == "hour") {
hour = vint;
} else if (k == "min") {
minute = vint;
} else if (k == "sec") {
second = vint;
} else {
throw exceptions::invalid_request_exception(format("invalid date table field: '{}'", k));
}
}
}
if (!year || !month || !day) {
throw exceptions::invalid_request_exception("date table must have year, month and day");
}
return lua_date_table{*year, *month, *day, hour, minute, second};
}
struct simple_date_return_visitor {
lua_State* l;
template <typename T>
uint32_t operator()(const T&) {
throw exceptions::invalid_request_exception("date must be a string, integer or date table");
}
uint32_t operator()(const utils::multiprecision_int& v) {
if (v > std::numeric_limits<uint32_t>::max()) {
throw exceptions::invalid_request_exception("date value must fit in 32 bits");
}
return uint32_t(v);
}
uint32_t operator()(const std::string_view& v) {
return simple_date_type_impl::from_string_view(v);
}
uint32_t operator()(const lua_table&);
};
struct timestamp_return_visitor {
lua_State* l;
template <typename T>
db_clock::time_point operator()(const T&) {
throw exceptions::invalid_request_exception("timestamp must be a string, integer or date table");
}
db_clock::time_point operator()(const utils::multiprecision_int& v) {
int64_t v2 = int64_t(v);
if (v2 == v) {
return db_clock::time_point(db_clock::duration(v2));
}
throw exceptions::invalid_request_exception("timestamp value must fit in signed 64 bits");
}
db_clock::time_point operator()(const std::string_view& v) {
return timestamp_type_impl::from_string_view(v);
}
db_clock::time_point operator()(const lua_table&);
};
struct from_lua_visitor {
lua_State* l;
data_value operator()(const reversed_type_impl& t) {
// This is unreachable since reversed_type_impl is used only
// in the tables. The function return the underlying type.
abort();
}
data_value operator()(const empty_type_impl& t) {
// This is unreachable since empty types are not user visible.
abort();
}
data_value operator()(const decimal_type_impl& t) {
struct visitor {
big_decimal operator()(const double& b) {
throw exceptions::invalid_request_exception("value is not a decimal");
}
big_decimal operator()(const big_decimal& b) {
return b;
}
};
return visit_lua_decimal(l, -1, visitor{});
}
data_value operator()(const varint_type_impl& t) {
return get_varint(l, -1);
}
data_value operator()(const duration_type_impl& t) {
return visit_lua_value(l, -1, make_visitor(
[] (const auto&) -> cql_duration {
throw exceptions::invalid_request_exception("a duration must be of the form { months = v1, days = v2, nanoseconds = v3 }");
},
[] (const std::string_view& v) {
return cql_duration(v);
},
[this] (const lua_table&) {
int32_t months = 0;
int32_t days = 0;
int64_t nanoseconds = 0;
lua_pushnil(l);
while (lua_next(l, -2) != 0) {
auto k = get_string(l, -2);
auto v = get_varint(l, -1);
lua_pop(l, 1);
if (k == "months") {
months = int32_t(v);
if (v != months) {
throw exceptions::invalid_request_exception(seastar::format("{} months doesn't fit in a 32 bit integer", v.str()));
}
} else if (k == "days") {
days = int32_t(v);
if (v != days) {
throw exceptions::invalid_request_exception(seastar::format("{} days doesn't fit in a 32 bit integer", v.str()));
}
} else if (k == "nanoseconds") {
nanoseconds = int64_t(v);
if (v != nanoseconds) {
throw exceptions::invalid_request_exception(seastar::format("{} nanoseconds doesn't fit in a 64 bit integer", v.str()));
}
} else {
throw exceptions::invalid_request_exception(format("invalid duration field: '{}'", k));
}
}
return cql_duration(months_counter(months), days_counter(days), nanoseconds_counter(nanoseconds));
}));
}
data_value operator()(const set_type_impl& t) {
std::vector<data_value> elements;
const data_type& element_type = t.get_elements_type();
lua_pushnil(l);
while (lua_next(l, -2) != 0) {
if (!lua_toboolean(l, -1)) {
throw exceptions::invalid_request_exception("sets are represented with tables with true values");
}
lua_pop(l, 1);
elements.push_back(convert_from_lua(l, element_type));
}
std::sort(elements.begin(), elements.end(), [&](const data_value& a, const data_value& b) {
// FIXME: this is madness, we have to be able to compare without serializing!
return element_type->less(a.serialize_nonnull(), b.serialize_nonnull());
});
return make_set_value(t.shared_from_this(), std::move(elements));
}
data_value operator()(const map_type_impl& t) {
const data_type& key_type = t.get_keys_type();
const data_type& value_type = t.get_values_type();
using map_pair = std::pair<data_value, data_value>;
std::vector<map_pair> elements;
lua_pushnil(l);
while (lua_next(l, -2) != 0) {
auto v = convert_from_lua(l, value_type);
lua_pop(l, 1);
auto k = convert_from_lua(l, key_type);
elements.push_back({k, v});
}
std::sort(elements.begin(), elements.end(), [&](const map_pair& a, const map_pair& b) {
// FIXME: this is madness, we have to be able to compare without serializing!
return key_type->less(a.first.serialize_nonnull(), b.first.serialize_nonnull());
});
return make_map_value(t.shared_from_this(), std::move(elements));
}
data_value operator()(const list_type_impl& t) {
if (!lua_istable(l, -1)) {
throw exceptions::invalid_request_exception("value is not a table");
}
const data_type& elements_type = t.get_elements_type();
using table_pair = std::pair<utils::multiprecision_int, data_value>;
std::vector<table_pair> pairs;
lua_pushnil(l);
while (lua_next(l, -2) != 0) {
auto v = convert_from_lua(l, elements_type);
lua_pop(l, 1);
pairs.push_back({get_varint(l, -1), v});
}
std::sort(pairs.begin(), pairs.end(), [] (const table_pair& a, const table_pair& b) {
return a.first < b.first;
});
size_t num_elements = pairs.size();
std::vector<data_value> elements;
for (size_t i = 0; i < num_elements; ++i) {
if (utils::multiprecision_int(i + 1) != pairs[i].first) {
throw exceptions::invalid_request_exception("table is not a sequence");
}
elements.push_back(pairs[i].second);
}
return make_list_value(t.shared_from_this(), std::move(elements));
}
data_value operator()(const tuple_type_impl& t) {
if (!lua_istable(l, -1)) {
throw exceptions::invalid_request_exception("value is not a table");
}
size_t num_elements = t.size();
std::vector<std::optional<data_value>> opt_elements(num_elements);
lua_pushnil(l);
while (lua_next(l, -2) != 0) {
auto k_varint = get_varint(l, -2);
if (k_varint > num_elements || k_varint < 1) {
throw exceptions::invalid_request_exception(
seastar::format("key {} is not valid for a sequence of size {}", k_varint.str(), num_elements));
}
size_t k = size_t(k_varint);
opt_elements[k - 1] = convert_from_lua(l, t.type(k - 1));
lua_pop(l, 1);
}
std::vector<data_value> elements;
elements.reserve(num_elements);
for (size_t i = 0; i < num_elements; ++i) {
if (!opt_elements[i]) {
throw exceptions::invalid_request_exception(
format("key {} missing in sequence of size {}", i + 1, num_elements));
}
elements.push_back(*opt_elements[i]);
}
return make_tuple_value(t.shared_from_this(), std::move(elements));
}
data_value operator()(const vector_type_impl& t) {
if (!lua_istable(l, -1)) {
throw exceptions::invalid_request_exception("value is not a table");
}
const data_type& elements_type = t.get_elements_type();
vector_dimension_t num_elements = t.get_dimension();
using table_pair = std::pair<utils::multiprecision_int, data_value>;
std::vector<table_pair> pairs;
lua_pushnil(l);
while (lua_next(l, -2) != 0) {
auto v = convert_from_lua(l, elements_type);
lua_pop(l, 1);
auto k = get_varint(l, -1);
if (k > num_elements || k < 1) {
throw exceptions::invalid_request_exception(
seastar::format("key {} is not valid for a sequence of size {}", k.str(), num_elements));
}
pairs.push_back({k, v});
}
std::sort(pairs.begin(), pairs.end(), [] (const table_pair& a, const table_pair& b) {
return a.first < b.first;
});
std::vector<data_value> elements;
for (size_t i = 0; i < num_elements; ++i) {
if (utils::multiprecision_int(i + 1) != pairs[i].first) {
throw exceptions::invalid_request_exception(
format("key {} missing in sequence of size {}", i + 1, num_elements));
}
elements.push_back(pairs[i].second);
}
return make_vector_value(t.shared_from_this(), std::move(elements));
}
data_value operator()(const user_type_impl& t) {
size_t num_fields = t.field_types().size();
std::unordered_map<sstring, std::pair<unsigned, data_type>> field_types;
field_types.reserve(num_fields);
for (unsigned i = 0; i < num_fields; ++i) {
field_types.insert({t.field_name_as_string(i), {i, t.field_type(i)}});
}
std::vector<std::optional<data_value>> opt_elements(num_fields);
lua_pushnil(l);
while (lua_next(l, -2) != 0) {
auto s = get_string(l, -2);
auto iter = field_types.find(s);
if (iter == field_types.end()) {
throw exceptions::invalid_request_exception(format("invalid UDT field '{}'", s));
}
const auto &p = iter->second;
auto v = convert_from_lua(l, p.second);
lua_pop(l, 1);
opt_elements[p.first] = std::move(v);
}
std::vector<data_value> elements;
elements.reserve(num_fields);
for (size_t i = 0; i < num_fields; ++i) {
if (!opt_elements[i]) {
throw exceptions::invalid_request_exception(
format("key {} missing in udt {}", t.field_name_as_string(i), t.get_name_as_string()));
}
elements.push_back(*opt_elements[i]);
}
return make_user_value(t.shared_from_this(), std::move(elements));
}
data_value operator()(const inet_addr_type_impl& t) {
return t.from_string_view(get_string(l, -1));
}
data_value operator()(const uuid_type_impl&) {
return uuid_type_impl::from_string_view(get_string(l, -1));
}
data_value operator()(const timeuuid_type_impl&) {
return timeuuid_native_type{timeuuid_type_impl::from_string_view(get_string(l, -1))};
}
data_value operator()(const bytes_type_impl& t) {
sstring v = get_string(l, -1);
return data_value(bytes(reinterpret_cast<const int8_t*>(v.data()), v.size()));
}
data_value operator()(const utf8_type_impl& t) {
sstring s = get_string(l, -1);
auto error_pos = utils::utf8::validate_with_error_position(reinterpret_cast<uint8_t*>(s.data()), s.size());
if (error_pos) {
throw exceptions::invalid_request_exception(format("value is not valid utf8, invalid character at byte offset {}", *error_pos));
}
return s;
}
data_value operator()(const ascii_type_impl& t) {
sstring s = get_string(l, -1);
if (utils::ascii::validate(reinterpret_cast<uint8_t*>(s.data()), s.size())) {
return ascii_native_type{std::move(s)};
}
throw exceptions::invalid_request_exception("value is not valid ascii");
}
data_value operator()(const boolean_type_impl& t) {
return bool(lua_toboolean(l, -1));
}
template <typename T> data_value operator()(const floating_type_impl<T>& t) {
return visit_lua_number(l, -1, make_visitor(
[] (const big_decimal& v) -> T { return decimal_to_double(v); },
[] (const auto& v) { return T(v); }
));
}
int64_t get_integer() {
return from_varint_to_integer(get_varint(l, -1));
}
data_value operator()(const timestamp_date_base_class& t) {
return visit_lua_value(l, -1, timestamp_return_visitor{l});
}
data_value operator()(const time_type_impl& t) {
return time_native_type{visit_lua_value(l, -1, make_visitor(
[] (const auto&) -> int64_t {
throw exceptions::invalid_request_exception("time must be a string or an integer");
},
[] (const utils::multiprecision_int& v) {
int64_t v2 = int64_t(v);
if (v2 == v) {
return v2;
}
throw exceptions::invalid_request_exception("time value must fit in signed 64 bits");
},
[] (const std::string_view& v) {
return time_type_impl::from_string_view(v);
}
))};
}
data_value operator()(const counter_type_impl&) {
// No data_value ever has a counter type, it is represented
// with long_type instead.
return get_integer();
}
template <typename T> data_value operator()(const integer_type_impl<T>& t) {
return T(get_integer());
}
data_value operator()(const simple_date_type_impl& t) {
return simple_date_native_type{visit_lua_value(l, -1, simple_date_return_visitor{l})};
}
};
uint32_t simple_date_return_visitor::operator()(const lua_table&) {
auto table = get_lua_date_table(l, -1);
if (table.hour || table.minute || table.second) {
throw exceptions::invalid_request_exception("date type has no hour, minute or second");
}
date::year_month_day ymd{date::year{table.year}, date::month{table.month}, date::day{table.day}};
int64_t days = date::local_days(ymd).time_since_epoch().count() + (1UL << 31);
return (*this)(utils::multiprecision_int(days));
}
db_clock::time_point timestamp_return_visitor::operator()(const lua_table&) {
auto table = get_lua_date_table(l, -1);
boost::gregorian::date date(table.year, table.month, table.day);
boost::posix_time::time_duration time(table.hour.value_or(12), table.minute.value_or(0), table.second.value_or(0));
boost::posix_time::ptime timestamp(date, time);
int64_t msec = (timestamp - boost::posix_time::from_time_t(0)).total_milliseconds();
return (*this)(utils::multiprecision_int(msec));
}
}
static data_value convert_from_lua(lua_State* l, const data_type& type) {
if (lua_isnil(l, -1)) {
return data_value::make_null(type);
}
return ::visit(*type, from_lua_visitor{l});
}
static bytes_opt convert_return(lua_slice_state &l, const data_type& return_type) {
int num_return_vals = lua_gettop(l);
if (num_return_vals != 1) {
throw exceptions::invalid_request_exception(
format("{} values returned, expected {}", num_return_vals, 1));
}
// FIXME: It should be possible to avoid creating the data_value,
// or even better, change the function::execute interface to
// return a data_value instead of bytes_opt.
return convert_from_lua(l, return_type).serialize();
}
void lua::push_sstring(lua_State* l, const sstring& v) {
lua_pushlstring(l, v.c_str(), v.size());
}
static void push_argument(lua_State* l, const data_value& arg);
namespace {
struct to_lua_visitor {
lua_State* l;
void operator()(const varint_type_impl& t, const emptyable<utils::multiprecision_int>* v) {
push_cpp_int(l, *v);
}
void operator()(const decimal_type_impl& t, const emptyable<big_decimal>* v) {
push_big_decimal(l, *v);
}
void operator()(const counter_type_impl& t, const void* v) {
// This is unreachable since deserialize_visitor for
// counter_type_impl return a long.
abort();
}
void operator()(const empty_type_impl& t, const void* v) {
// This is unreachable since empty types are not user visible.
abort();
}
void operator()(const reversed_type_impl& t, const void* v) {
// This is unreachable since reversed_type_impl is used only
// in the tables. The function gets the underlying type.
abort();
}
void operator()(const map_type_impl& t, const std::vector<std::pair<data_value, data_value>>* v) {
// returns the table { k1 = v1, k2 = v2, ...}
lua_createtable(l, 0, v->size());
for (const auto& p : *v) {
push_argument(l, p.first);
push_argument(l, p.second);
lua_rawset(l, -3);
}
}
void operator()(const user_type_impl& t, const std::vector<data_value>* v) {
// returns the table { field1 = v1, field2 = v2, ...}
lua_createtable(l, 0, v->size());
for (int i = 0, n = v->size(); i < n; ++i) {
push_sstring(l, t.field_name_as_string(i));
push_argument(l, (*v)[i]);
lua_rawset(l, -3);
}
}
void operator()(const set_type_impl& t, const std::vector<data_value>* v) {
// returns the table { v1 = true, v2 = true, ...}
lua_createtable(l, 0, v->size());
for (const data_value& dv : *v) {
push_argument(l, dv);
lua_pushboolean(l, true);
lua_rawset(l, -3);
}
}
template <typename T>
void operator()(const concrete_type<std::vector<data_value>, T>& t, const std::vector<data_value>* v) {
// returns the table {v1, v2, ...}
lua_createtable(l, v->size(), 0);
int i = 0;
for (const data_value& dv : *v) {
push_argument(l, dv);
lua_rawseti(l, -2, ++i);
}
}
void operator()(const boolean_type_impl& t, const emptyable<bool>* v) {
lua_pushboolean(l, *v);
}
template <typename T>
void operator()(const floating_type_impl<T>& t, const emptyable<T>* v) {
// floats are converted to double
lua_pushnumber(l, *v);
}
template <typename T>
void operator()(const integer_type_impl<T>& t, const emptyable<T>* v) {
// Integers are converted to 64 bits
lua_pushinteger(l, *v);
}
void operator()(const bytes_type_impl& t, const bytes* v) {
// lua strings can hold arbitrary blobs
lua_pushlstring(l, reinterpret_cast<const char*>(v->c_str()), v->size());
}
void operator()(const string_type_impl& t, const sstring* v) {
push_sstring(l, *v);
}
void operator()(const time_type_impl& t, const emptyable<int64_t>* v) {
// nanoseconds since midnight
lua_pushinteger(l, *v);
}
void operator()(const timestamp_date_base_class& t, const timestamp_date_base_class::native_type* v) {
// milliseconds since epoch
lua_pushinteger(l, v->get().time_since_epoch().count());
}
void operator()(const simple_date_type_impl& t, const emptyable<uint32_t>* v) {
// number of days since epoch + 2^31
lua_pushinteger(l, *v);
}
void operator()(const duration_type_impl& t, const emptyable<cql_duration>* v) {
// returns the table { months = v1, days = v2, nanoseconds = v3 }
const cql_duration& d = v->get();
lua_createtable(l, 3, 0);
lua_pushinteger(l, d.months);
lua_setfield(l, -2, "months");
lua_pushinteger(l, d.days);
lua_setfield(l, -2, "days");
lua_pushinteger(l, d.nanoseconds);
lua_setfield(l, -2, "nanoseconds");
}
void operator()(const inet_addr_type_impl& t, const emptyable<seastar::net::inet_address>* v) {
// returns a string
sstring s = fmt::to_string(v->get());
push_sstring(l, s);
}
void operator()(const concrete_type<utils::UUID>&, const emptyable<utils::UUID>* v) {
// returns a string
push_sstring(l, fmt::to_string(v->get())) ;
}
};
}
static void push_argument(lua_State* l, const data_value& arg) {
if (arg.is_null()) {
lua_pushnil(l);
return;
}
::visit(arg, to_lua_visitor{l});
}
// run the script for at most max_instructions
future<bytes_opt> lua::run_script(lua::bitcode_view bitcode, const std::vector<data_value>& values, data_type return_type, const lua::runtime_config& cfg) {
lua_slice_state l = load_script(cfg, bitcode);
unsigned nargs = values.size();
if (!lua_checkstack(l, nargs)) {
throw std::runtime_error("could push args to the stack");
}
for (const data_value& arg : values) {
push_argument(l, arg);
}
// We don't update the timeout once we start executing the function
using millisecond = std::chrono::duration<double, std::milli>;
using duration = std::chrono::system_clock::duration;
duration elapsed{0};
duration timeout = std::chrono::duration_cast<duration>(millisecond(cfg.timeout_in_ms));
return repeat_until_value([l = std::move(l), elapsed, return_type, nargs, timeout = std::move(timeout)] () mutable {
// Set the hook before resuming. We have to do it here since the hook can reset itself
// if it detects we are spending too much time in C.
// The hook will be called after 1000 instructions.
lua_sethook(l, debug_hook, LUA_MASKCALL | LUA_MASKCOUNT, 1000);
auto start = ::now();
LUA_504_PLUS(int nresults;)
switch (lua_resume(l, nullptr, nargs LUA_504_PLUS(, &nresults))) {
case LUA_OK:
return make_ready_future<std::optional<bytes_opt>>(convert_return(l, return_type));
case LUA_YIELD: {
nargs = 0;
elapsed += ::now() - start;
if (elapsed > timeout) {
millisecond ms = elapsed;
throw exceptions::invalid_request_exception(format("lua execution timeout: {}ms elapsed", ms.count()));
}
return make_ready_future<std::optional<bytes_opt>>(std::nullopt);
}
default:
throw exceptions::invalid_request_exception(std::string("lua execution failed: ") +
lua_tostring(l, -1));
}
});
}
namespace lua {
void register_metatables(lua_State* l) {
luaL_newmetatable(l, scylla_decimal_metatable_name);
lua_pushvalue(l, -1);
lua_setfield(l, -2, "__index");
luaL_setfuncs(l, decimal_methods, 0);
lua_pop(l, 1);
}
void push_data_value(lua_State* l, const data_value& value) {
push_argument(l, value);
}
data_value pop_data_value(lua_State* l, const data_type& type) {
return convert_from_lua(l, type);
}
}