Switch vector dimension handling to the fixed-width `uint32_t` type, update parsing/validation, and add boundary tests. The dimension is first parsed as `unsigned long`, which is guaranteed to be **at least** 32 bits wide, and the validated value can then be safely narrowed to `uint32_t`. Move `MAX_VECTOR_DIMENSION` from `cql3_type::raw_vector` to `cql3_type` so that it is publicly visible to checks outside the class. Add tests to verify the type boundaries. Fixes: https://scylladb.atlassian.net/browse/SCYLLADB-223 Signed-off-by: Yaniv Kaul <yaniv.kaul@scylladb.com> Co-authored-by: Dawid Pawlik <dawid.pawlik@scylladb.com> Closes scylladb/scylladb#28762
1154 lines
41 KiB
C++
1154 lines
41 KiB
C++
/*
|
|
* Copyright (C) 2019-present ScyllaDB
|
|
*/
|
|
|
|
/*
|
|
* SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
|
|
*/
|
|
|
|
#include <boost/date_time/gregorian/greg_date.hpp>
|
|
#include <boost/date_time/posix_time/posix_time.hpp>
|
|
#include <random>
|
|
#include "lua.hh"
|
|
#include "lang/lua_scylla_types.hh"
|
|
#include "exceptions/exceptions.hh"
|
|
#include "types/concrete_types.hh"
|
|
#include "utils/assert.hh"
|
|
#include "utils/utf8.hh"
|
|
#include "utils/ascii.hh"
|
|
#include "utils/date.h"
|
|
#include <seastar/core/align.hh>
|
|
#include <lua.hpp>
|
|
#include "seastarx.hh"
|
|
|
|
// Version-compatibility helpers: LUA_504_PLUS(args) and LUA_505_PLUS(args)
// expand to `args` when building against a Lua at least that new, and to
// nothing otherwise. The arguments may contain commas (e.g. `, seed`).
// Standard variadic macros (`...` / __VA_ARGS__) are used instead of the GNU
// named-variadic extension so the header compiles with any conforming compiler.

// Lua 5.4 added an extra parameter to lua_resume
#if LUA_VERSION_NUM >= 504
# define LUA_504_PLUS(...) __VA_ARGS__
#else
# define LUA_504_PLUS(...)
#endif

// Lua 5.5 added a seed parameter to lua_newstate
#if LUA_VERSION_NUM >= 505
# define LUA_505_PLUS(...) __VA_ARGS__
#else
# define LUA_505_PLUS(...)
#endif
|
|
|
|
using namespace seastar;
|
|
using namespace lua;
|
|
|
|
static logging::logger lua_logger("lua");
|
|
|
|
namespace {
|
|
struct alloc_state {
|
|
size_t allocated = 0;
|
|
size_t max;
|
|
size_t max_contiguous;
|
|
alloc_state(size_t max, size_t max_contiguous)
|
|
: max(max)
|
|
, max_contiguous(max_contiguous) {
|
|
// The max and max_contiguous limits are responsible for avoiding overflows.
|
|
SCYLLA_ASSERT(max + max_contiguous >= max);
|
|
}
|
|
};
|
|
|
|
struct lua_closer {
|
|
void operator()(lua_State* l) {
|
|
lua_close(l);
|
|
}
|
|
};
|
|
|
|
static const char scylla_decimal_metatable_name[] = "Scylla.decimal";
|
|
|
|
class lua_slice_state {
|
|
std::unique_ptr<alloc_state> a_state;
|
|
std::unique_ptr<lua_State, lua_closer> _l;
|
|
public:
|
|
lua_slice_state(std::unique_ptr<alloc_state> a_state, std::unique_ptr<lua_State, lua_closer> l)
|
|
: a_state(std::move(a_state))
|
|
, _l(std::move(l)) {}
|
|
operator lua_State*() { return _l.get(); }
|
|
};
|
|
}
|
|
|
|
static void* lua_alloc(void* ud, void* ptr, size_t osize, size_t nsize) {
|
|
auto* s = reinterpret_cast<alloc_state*>(ud);
|
|
|
|
// avoid realloc(nullptr, 0), which would allocate.
|
|
if (nsize == 0 && ptr == nullptr) {
|
|
return nullptr;
|
|
}
|
|
|
|
if (nsize > s->max_contiguous) {
|
|
return nullptr;
|
|
}
|
|
|
|
size_t next = s->allocated + nsize;
|
|
|
|
// The max and max_contiguous limits should be small enough to avoid overflows.
|
|
SCYLLA_ASSERT(next >= s->allocated);
|
|
|
|
if (ptr) {
|
|
next -= osize;
|
|
}
|
|
|
|
if (next > s->max) {
|
|
lua_logger.info("allocation failed. already allocated = {}, next total = {}, max = {}", s->allocated, next, s->max);
|
|
return nullptr;
|
|
}
|
|
|
|
// FIXME: Given that we have osize, we can probably do better when
|
|
// SEASTAR_DEFAULT_ALLOCATOR is false
|
|
void* ret = realloc(ptr, nsize);
|
|
|
|
if (nsize == 0 || ret != nullptr) {
|
|
s->allocated = next;
|
|
}
|
|
return ret;
|
|
}
|
|
|
|
// Lua standard libraries opened in every UDF state by load_script_l():
// the base library plus string, coroutine and table. The list ends with a
// sentinel entry whose fields are both null.
static const luaL_Reg loadedlibs[] = {
    {"_G", luaopen_base},
    {LUA_STRLIBNAME, luaopen_string},
    {LUA_COLIBNAME, luaopen_coroutine},
    {LUA_TABLIBNAME, luaopen_table},
    {nullptr, nullptr},
};
|
|
|
|
// Debug hook used to preempt running Lua code: when need_preempt() reports that
// the current task should give up the CPU, the Lua thread is yielded at the
// next count event.
static void debug_hook(lua_State* l, lua_Debug* ar) {
    if (need_preempt()) {
        // The lua manual says that only count and line events can yield. Of those we only use count.
        if (ar->event == LUA_HOOKCOUNT) {
            if (lua_yield(l, 0)) {
                SCYLLA_ASSERT(0 && "lua_yield failed");
            }
        } else {
            // Set the hook to stop at the very next lua instruction, where we will be able to yield.
            lua_sethook(l, debug_hook, LUA_MASKCOUNT, 1);
        }
    }
}
|
|
|
|
// Creates an empty Lua state whose allocations go through lua_alloc() and are
// limited by cfg.max_bytes (total) and cfg.max_contiguous (per request).
// Throws std::runtime_error when the state cannot be created.
static lua_slice_state new_lua(const lua::runtime_config& cfg) {
    auto limits = std::make_unique<alloc_state>(cfg.max_bytes, cfg.max_contiguous);
#if LUA_VERSION_NUM >= 505
    // Lua 5.5 added a seed parameter to lua_newstate; draw one per state from a
    // per-thread engine that is itself seeded once from std::random_device.
    static thread_local std::default_random_engine seed_source{std::random_device{}()};
    auto seed = seed_source();
#endif
    lua_State* raw_state = lua_newstate(lua_alloc, limits.get() LUA_505_PLUS(, seed));
    if (raw_state == nullptr) {
        throw std::runtime_error("could not create lua state");
    }
    return lua_slice_state{std::move(limits), std::unique_ptr<lua_State, lua_closer>{raw_state}};
}
|
|
|
|
static int string_writer(lua_State*, const void* p, size_t size, void* data) {
|
|
luaL_addlstring(reinterpret_cast<luaL_Buffer*>(data), reinterpret_cast<const char*>(p), size);
|
|
return 0;
|
|
}
|
|
|
|
static int compile_l(lua_State* l) {
|
|
const auto& script = *reinterpret_cast<sstring*>(lua_touserdata(l, 1));
|
|
luaL_Buffer buf;
|
|
luaL_buffinit(l, &buf);
|
|
|
|
if (luaL_loadbufferx(l, script.c_str(), script.size(), "<internal>", "t")) {
|
|
lua_error(l);
|
|
}
|
|
if (lua_dump(l, string_writer, &buf, true)) {
|
|
luaL_error(l, "lua_dump failed");
|
|
}
|
|
luaL_pushresult(&buf);
|
|
|
|
return 1;
|
|
}
|
|
|
|
// Compiles a Lua UDF body into a stripped binary chunk.
//
// When arg_names is non-empty, the script is prefixed with
// "local <arg1>,<arg2>,... = ...;\n" so that the chunk's varargs are bound to
// the UDF argument names. Compilation runs in a state created by new_lua(cfg),
// so it is subject to the configured memory limits.
// Throws exceptions::invalid_request_exception if the script does not compile.
sstring lua::compile(const runtime_config& cfg, const std::vector<sstring>& arg_names, sstring script) {
    if (!arg_names.empty()) {
        // In Lua, all chunks are compiled to vararg functions. To use
        // the UDF argument names, we start the chunk with "local
        // arg1,arg2,etc = ...", which captures the arguments when the
        // chunk is called.
        std::ostringstream os;
        os << "local ";
        const char* separator = "";
        for (const auto& name : arg_names) {
            os << separator << name;
            separator = ",";
        }
        os << " = ...;\n" << script;
        script = os.str();
    }
    lua_slice_state l = new_lua(cfg);

    // Run the load from lua_pcall so we don't have to handle longjmp.
    lua_pushcfunction(l, compile_l);
    lua_pushlightuserdata(l, &script);
    if (lua_pcall(l, 1, 1, 0)) {
        // lua_tostring() returns nullptr for error objects that are neither
        // strings nor numbers; never pass a null pointer to std::string's operator+.
        const char* msg = lua_tostring(l, -1);
        throw exceptions::invalid_request_exception(std::string("could not compile: ") + (msg ? msg : "unknown error"));
    }

    size_t len;
    const char* p = lua_tolstring(l, -1, &len);
    return sstring(p, len);
}
|
|
|
|
// Returns the big_decimal stored in the userdata at stack index `arg`, or
// nullptr when that value is not userdata carrying the decimal metatable.
static big_decimal* get_decimal(lua_State* l, int arg) {
    void* raw = luaL_testudata(l, arg, scylla_decimal_metatable_name);
    if (raw == nullptr) {
        return nullptr;
    }
    // The object lives at the first address inside the userdata block that is
    // suitably aligned for big_decimal (presumably matching aligned_user_data()).
    char* aligned = align_up(static_cast<char*>(raw), alignof(big_decimal));
    return reinterpret_cast<big_decimal*>(aligned);
}
|
|
|
|
// Pushes a copy of `v` onto the Lua stack as userdata tagged with the decimal
// metatable (scylla_decimal_metatable_name).
static void push_big_decimal(lua_State* l, const big_decimal& v) {
    new (aligned_user_data<big_decimal>(l)) big_decimal(v);
    luaL_setmetatable(l, scylla_decimal_metatable_name);
}
|
|
|
|
// Pushes an arbitrary-precision integer as a decimal userdata built with
// big_decimal(0, v) (presumably scale 0, unscaled value v).
static void push_cpp_int(lua_State* l, const utils::multiprecision_int& v) {
    const big_decimal as_decimal(0, v);
    push_big_decimal(l, as_decimal);
}
|
|
|
|
// Tag type passed to visitors in place of a Lua table. It carries no data; the
// table itself stays on the Lua stack for the visitor to traverse.
struct lua_table {};
|
|
|
|
template <typename Func>
|
|
using lua_visit_ret_type = std::invoke_result_t<Func, const double&>;
|
|
|
|
template <typename Func>
|
|
concept CanHandleRawLuaTypes = requires(Func f) {
|
|
{ f(*static_cast<const long long*>(nullptr)) } -> std::same_as<lua_visit_ret_type<Func>>;
|
|
{ f(*static_cast<const double*>(nullptr)) } -> std::same_as<lua_visit_ret_type<Func>>;
|
|
{ f(*static_cast<const big_decimal*>(nullptr)) } -> std::same_as<lua_visit_ret_type<Func>>;
|
|
{ f(*static_cast<const std::string_view*>(nullptr)) } -> std::same_as<lua_visit_ret_type<Func>>;
|
|
{ f(*static_cast<const lua_table*>(nullptr)) } -> std::same_as<lua_visit_ret_type<Func>>;
|
|
};
|
|
|
|
template <typename Func>
|
|
requires CanHandleRawLuaTypes<Func>
|
|
static auto visit_lua_raw_value(lua_State* l, int index, Func&& f) {
|
|
switch (lua_type(l, index)) {
|
|
case LUA_TNONE:
|
|
SCYLLA_ASSERT(0 && "Invalid index");
|
|
case LUA_TNUMBER:
|
|
if (lua_isinteger(l, index)) {
|
|
return f(lua_tointeger(l, index));
|
|
}
|
|
return f(lua_tonumber(l, index));
|
|
case LUA_TSTRING: {
|
|
size_t len;
|
|
const char* s = lua_tolstring(l, index, &len);
|
|
return f(std::string_view{s, len});
|
|
}
|
|
case LUA_TTABLE:
|
|
return f(lua_table{});
|
|
case LUA_TBOOLEAN:
|
|
case LUA_TFUNCTION:
|
|
case LUA_TNIL:
|
|
throw exceptions::invalid_request_exception("unexpected value");
|
|
case LUA_TUSERDATA:
|
|
return f(*get_decimal(l, index));
|
|
case LUA_TTHREAD:
|
|
case LUA_TLIGHTUSERDATA:
|
|
SCYLLA_ASSERT(0 && "We never make thread or light user data visible to scripts");
|
|
}
|
|
SCYLLA_ASSERT(0 && "invalid lua type");
|
|
}
|
|
|
|
template <typename Func>
|
|
static auto visit_decimal(const big_decimal &v, Func&& f) {
|
|
boost::multiprecision::cpp_rational r = v.as_rational();
|
|
const boost::multiprecision::cpp_int& dividend = numerator(r);
|
|
const boost::multiprecision::cpp_int& divisor = denominator(r);
|
|
if (dividend % divisor == 0) {
|
|
return f(utils::multiprecision_int(dividend/divisor));
|
|
}
|
|
return f(r.convert_to<double>());
|
|
}
|
|
|
|
template <typename Func>
|
|
concept CanHandleLuaTypes = requires(Func f) {
|
|
{ f(*static_cast<const double*>(nullptr)) } -> std::same_as<lua_visit_ret_type<Func>>;
|
|
{ f(*static_cast<const utils::multiprecision_int*>(nullptr)) } -> std::same_as<lua_visit_ret_type<Func>>;
|
|
{ f(*static_cast<const big_decimal*>(nullptr)) } -> std::same_as<lua_visit_ret_type<Func>>;
|
|
{ f(*static_cast<const std::string_view*>(nullptr)) } -> std::same_as<lua_visit_ret_type<Func>>;
|
|
{ f(*static_cast<const lua_table*>(nullptr)) } -> std::same_as<lua_visit_ret_type<Func>>;
|
|
};
|
|
|
|
template <typename Func>
|
|
requires CanHandleLuaTypes<Func>
|
|
static auto visit_lua_value(lua_State* l, int index, Func&& f) {
|
|
struct visitor {
|
|
lua_State* l;
|
|
int index;
|
|
Func& f;
|
|
auto operator()(const long long& v) { return f(utils::multiprecision_int(v)); }
|
|
auto operator()(const utils::multiprecision_int& v) { return f(v); }
|
|
auto operator()(const double& v) {
|
|
auto min = double(std::numeric_limits<long long>::min());
|
|
auto max = double(std::numeric_limits<long long>::max());
|
|
if (min <= v && v <= max && std::trunc(v) == v) {
|
|
return (*this)((long long)v);
|
|
}
|
|
// FIXME: We could use frexp to produce a decimal instead of a double
|
|
return f(v);
|
|
}
|
|
auto operator()(const std::string_view& v) {
|
|
big_decimal v2;
|
|
try {
|
|
v2 = big_decimal(v);
|
|
} catch (marshal_exception&) {
|
|
// The string is not a valid big_decimal. Let Lua try to convert it to a double.
|
|
int isnum;
|
|
double d = lua_tonumberx(l, index, &isnum);
|
|
if (isnum) {
|
|
return (*this)(d);
|
|
}
|
|
return f(v);
|
|
}
|
|
return (*this)(v2);
|
|
}
|
|
auto operator()(const big_decimal& v) {
|
|
struct visitor {
|
|
Func& f;
|
|
const big_decimal &d;
|
|
auto operator()(const double&) { return f(d); }
|
|
auto operator()(const utils::multiprecision_int& v) { return f(v); }
|
|
};
|
|
return visit_decimal(v, visitor{f, v});
|
|
}
|
|
auto operator()(const lua_table& v) {
|
|
return f(v);
|
|
}
|
|
};
|
|
return visit_lua_raw_value(l, index, visitor{l, index, f});
|
|
}
|
|
|
|
template <typename Func>
|
|
static auto visit_lua_number(lua_State* l, int index, Func&& f) {
|
|
return visit_lua_value(l, index, make_visitor(
|
|
[] (const std::string_view& v) -> std::invoke_result_t<Func, double> {
|
|
throw exceptions::invalid_request_exception("value is not a number");
|
|
},
|
|
[] (const lua_table&) -> std::invoke_result_t<Func, double> {
|
|
throw exceptions::invalid_request_exception("value is not a number");
|
|
},
|
|
std::forward<Func>(f)
|
|
));
|
|
}
|
|
|
|
// Like visit_lua_number(), but integral values are widened to big_decimal, so
// `f` only ever sees big_decimal or double.
template <typename Func>
static auto visit_lua_decimal(lua_State* l, int index, Func&& f) {
    auto widen_integer = [&f] (const utils::multiprecision_int& i) {
        return f(big_decimal(0, i));
    };
    auto pass_through = [&f] (const auto& x) {
        return f(x);
    };
    return visit_lua_number(l, index, make_visitor(std::move(widen_integer), std::move(pass_through)));
}
|
|
|
|
// __gc metamethod: runs the destructor of the big_decimal held by the userdata
// passed as argument 1.
static int decimal_gc(lua_State *l) {
    big_decimal* d = get_decimal(l, 1);
    std::destroy_at(d);
    return 0;
}
|
|
|
|
static double decimal_to_double(const big_decimal &d) {
|
|
return visit_decimal(d, [] (auto&& v) { return double(v); });
|
|
}
|
|
|
|
// Returns the decimal operand of a binary metamethod call. If the first
// operand is not a decimal, the top of the stack is moved into slot 1 with
// lua_insert(), so afterwards slot 1 holds the decimal and the top of the stack
// holds the other operand (assuming only the two operands are on the stack).
static const big_decimal& get_decimal_in_binary_op(lua_State* l) {
    if (auto* first = get_decimal(l, 1)) {
        return *first;
    }
    lua_insert(l, 1);
    auto* moved = get_decimal(l, 1);
    SCYLLA_ASSERT(moved);
    return *moved;
}
|
|
|
|
template<typename Func>
|
|
static void visit_decimal_bin_op(lua_State* l, Func&& F) {
|
|
const auto& a = get_decimal_in_binary_op(l);
|
|
struct bin_op_visitor {
|
|
const big_decimal& a;
|
|
Func& F;
|
|
lua_State* l;
|
|
void operator()(const double& b) {
|
|
lua_pushnumber(l, F(decimal_to_double(a), b));
|
|
}
|
|
void operator()(const big_decimal& b) {
|
|
push_big_decimal(l, F(a, b));
|
|
}
|
|
};
|
|
|
|
visit_lua_decimal(l, -1, bin_op_visitor{a, F, l});
|
|
}
|
|
|
|
// __add metamethod for decimals; pushes exactly one result.
static int decimal_add(lua_State* l) {
    auto add = [] (auto&& lhs, auto&& rhs) { return lhs + rhs; };
    visit_decimal_bin_op(l, add);
    return 1;
}
|
|
|
|
// __sub metamethod for decimals; pushes exactly one result.
static int decimal_sub(lua_State* l) {
    auto subtract = [] (auto&& lhs, auto&& rhs) { return lhs - rhs; };
    visit_decimal_bin_op(l, subtract);
    return 1;
}
|
|
|
|
static const struct luaL_Reg decimal_methods[] {
|
|
{"__gc", decimal_gc},
|
|
{"__add", decimal_add},
|
|
{"__sub", decimal_sub},
|
|
{nullptr, nullptr}
|
|
};
|
|
|
|
// Protected-mode body of load_script(). Stack slot 1 holds a light userdata
// pointing at a lua::bitcode_view. Opens the libraries listed in loadedlibs,
// registers the Scylla metatables and loads the binary chunk, returning it.
static int load_script_l(lua_State* l) {
    const auto& view = *static_cast<lua::bitcode_view*>(lua_touserdata(l, 1));
    const auto& chunk = view.bitcode;

    for (const luaL_Reg& lib : loadedlibs) {
        if (lib.func == nullptr) {
            break;
        }
        luaL_requiref(l, lib.name, lib.func, 1);
        lua_pop(l, 1);
    }

    lua::register_metatables(l);

    if (luaL_loadbufferx(l, chunk.data(), chunk.size(), "<internal>", "b") != 0) {
        lua_error(l);
    }
    return 1;
}
|
|
|
|
// Creates a Lua state limited by `cfg` and loads `binary` into it. On success
// the loaded chunk is left on top of the state's stack.
// Throws std::runtime_error if initialization fails.
static lua_slice_state load_script(const lua::runtime_config& cfg, lua::bitcode_view binary) {
    lua_slice_state state = new_lua(cfg);

    // Run the initialization from lua_pcall so we don't have to
    // handle longjmp. We know that a new state has a few reserved
    // stack slots and the following push calls don't allocate.
    lua_pushcfunction(state, load_script_l);
    lua_pushlightuserdata(state, &binary);
    if (lua_pcall(state, 1, 1, 0) != 0) {
        throw std::runtime_error(std::string("could not initiate: ") + lua_tostring(state, -1));
    }
    return state;
}
|
|
|
|
// Floating-point duration counted in milliseconds.
using millisecond = std::chrono::duration<double, std::ratio<1, 1000>>;

// Current wall-clock time.
static std::chrono::system_clock::time_point now() {
    return std::chrono::system_clock::now();
}
|
|
|
|
// Reads the Lua value at `index` as an arbitrary-precision integer. Integral
// values are returned; other numeric values raise
// invalid_request_exception("value is not an integer"), and non-numeric values
// are rejected by visit_lua_number().
static utils::multiprecision_int get_varint(lua_State* l, int index) {
    auto accept = [] (const utils::multiprecision_int& v) { return v; };
    auto reject = [] (const auto&) -> utils::multiprecision_int {
        throw exceptions::invalid_request_exception("value is not an integer");
    };
    return visit_lua_number(l, index, make_visitor(std::move(reject), std::move(accept)));
}
|
|
|
|
// Converts the Lua value at `index` to a string. Integral values are printed
// with multiprecision_int::str(), other values through seastar::format("{}");
// tables raise invalid_request_exception("unexpected value").
static sstring get_string(lua_State *l, int index) {
    auto from_integer = [] (const utils::multiprecision_int& p) {
        return sstring(p.str());
    };
    auto from_other = [] (const auto& v) {
        return seastar::format("{}", v);
    };
    auto from_table = [] (const lua_table&) -> sstring {
        throw exceptions::invalid_request_exception("unexpected value");
    };
    return visit_lua_value(l, index, make_visitor(
        std::move(from_integer),
        std::move(from_other),
        std::move(from_table)));
}
|
|
|
|
// Forward declaration: the visitors below call convert_from_lua() recursively
// for element values, but it is defined after from_lua_visitor.
static data_value convert_from_lua(lua_State* l, const data_type& type);
|
|
|
|
namespace {
|
|
struct lua_date_table {
|
|
// The lua date table is documented at https://www.lua.org/pil/22.1.html
|
|
// date::year uses a int64_t, but there is no reason to try to
|
|
// support 64 bit years. In practice the limitations are
|
|
// * year_month_day::to_days hits a signed integer overflow for
|
|
// large years.
|
|
// * boost::gregorian only supports the years [1400,9999]
|
|
int32_t year;
|
|
|
|
// Both date::month and date::day use unsigned char.
|
|
unsigned char month;
|
|
unsigned char day;
|
|
std::optional<int32_t> hour;
|
|
std::optional<int32_t> minute;
|
|
std::optional<int32_t> second;
|
|
};
|
|
|
|
static lua_date_table get_lua_date_table(lua_State* l, int index) {
|
|
std::optional<int32_t> year;
|
|
std::optional<unsigned char> month;
|
|
std::optional<unsigned char> day;
|
|
std::optional<int32_t> hour;
|
|
std::optional<int32_t> minute;
|
|
std::optional<int32_t> second;
|
|
|
|
lua_pushnil(l);
|
|
while (lua_next(l, index - 1) != 0) {
|
|
auto k = get_string(l, index - 1);
|
|
auto v = get_varint(l, index);
|
|
lua_pop(l, 1);
|
|
if (k == "month") {
|
|
month = (unsigned char)v;
|
|
if (*month != v) {
|
|
throw exceptions::invalid_request_exception(seastar::format("month is too large: '{}'", v.str()));
|
|
}
|
|
} else if (k == "day") {
|
|
day = (unsigned char)v;
|
|
if (*day != v) {
|
|
throw exceptions::invalid_request_exception(seastar::format("day is too large: '{}'", v.str()));
|
|
}
|
|
} else {
|
|
int32_t vint(v);
|
|
if (vint != v) {
|
|
throw exceptions::invalid_request_exception(seastar::format("{} is too large: '{}'", k, v.str()));
|
|
}
|
|
if (k == "year") {
|
|
year = vint;
|
|
} else if (k == "hour") {
|
|
hour = vint;
|
|
} else if (k == "min") {
|
|
minute = vint;
|
|
} else if (k == "sec") {
|
|
second = vint;
|
|
} else {
|
|
throw exceptions::invalid_request_exception(format("invalid date table field: '{}'", k));
|
|
}
|
|
}
|
|
}
|
|
if (!year || !month || !day) {
|
|
throw exceptions::invalid_request_exception("date table must have year, month and day");
|
|
}
|
|
return lua_date_table{*year, *month, *day, hour, minute, second};
|
|
}
|
|
|
|
struct simple_date_return_visitor {
|
|
lua_State* l;
|
|
template <typename T>
|
|
uint32_t operator()(const T&) {
|
|
throw exceptions::invalid_request_exception("date must be a string, integer or date table");
|
|
}
|
|
uint32_t operator()(const utils::multiprecision_int& v) {
|
|
if (v > std::numeric_limits<uint32_t>::max()) {
|
|
throw exceptions::invalid_request_exception("date value must fit in 32 bits");
|
|
}
|
|
return uint32_t(v);
|
|
}
|
|
uint32_t operator()(const std::string_view& v) {
|
|
return simple_date_type_impl::from_string_view(v);
|
|
}
|
|
uint32_t operator()(const lua_table&);
|
|
};
|
|
|
|
struct timestamp_return_visitor {
|
|
lua_State* l;
|
|
template <typename T>
|
|
db_clock::time_point operator()(const T&) {
|
|
throw exceptions::invalid_request_exception("timestamp must be a string, integer or date table");
|
|
}
|
|
db_clock::time_point operator()(const utils::multiprecision_int& v) {
|
|
int64_t v2 = int64_t(v);
|
|
if (v2 == v) {
|
|
return db_clock::time_point(db_clock::duration(v2));
|
|
}
|
|
throw exceptions::invalid_request_exception("timestamp value must fit in signed 64 bits");
|
|
}
|
|
db_clock::time_point operator()(const std::string_view& v) {
|
|
return timestamp_type_impl::from_string_view(v);
|
|
}
|
|
db_clock::time_point operator()(const lua_table&);
|
|
};
|
|
|
|
struct from_lua_visitor {
|
|
lua_State* l;
|
|
|
|
data_value operator()(const reversed_type_impl& t) {
|
|
// This is unreachable since reversed_type_impl is used only
|
|
// in the tables. The function return the underlying type.
|
|
abort();
|
|
}
|
|
|
|
data_value operator()(const empty_type_impl& t) {
|
|
// This is unreachable since empty types are not user visible.
|
|
abort();
|
|
}
|
|
|
|
data_value operator()(const decimal_type_impl& t) {
|
|
struct visitor {
|
|
big_decimal operator()(const double& b) {
|
|
throw exceptions::invalid_request_exception("value is not a decimal");
|
|
}
|
|
big_decimal operator()(const big_decimal& b) {
|
|
return b;
|
|
}
|
|
};
|
|
return visit_lua_decimal(l, -1, visitor{});
|
|
}
|
|
|
|
data_value operator()(const varint_type_impl& t) {
|
|
return get_varint(l, -1);
|
|
}
|
|
|
|
data_value operator()(const duration_type_impl& t) {
|
|
return visit_lua_value(l, -1, make_visitor(
|
|
[] (const auto&) -> cql_duration {
|
|
throw exceptions::invalid_request_exception("a duration must be of the form { months = v1, days = v2, nanoseconds = v3 }");
|
|
},
|
|
[] (const std::string_view& v) {
|
|
return cql_duration(v);
|
|
},
|
|
[this] (const lua_table&) {
|
|
int32_t months = 0;
|
|
int32_t days = 0;
|
|
int64_t nanoseconds = 0;
|
|
lua_pushnil(l);
|
|
while (lua_next(l, -2) != 0) {
|
|
auto k = get_string(l, -2);
|
|
auto v = get_varint(l, -1);
|
|
lua_pop(l, 1);
|
|
if (k == "months") {
|
|
months = int32_t(v);
|
|
if (v != months) {
|
|
throw exceptions::invalid_request_exception(seastar::format("{} months doesn't fit in a 32 bit integer", v.str()));
|
|
}
|
|
} else if (k == "days") {
|
|
days = int32_t(v);
|
|
if (v != days) {
|
|
throw exceptions::invalid_request_exception(seastar::format("{} days doesn't fit in a 32 bit integer", v.str()));
|
|
}
|
|
} else if (k == "nanoseconds") {
|
|
nanoseconds = int64_t(v);
|
|
if (v != nanoseconds) {
|
|
throw exceptions::invalid_request_exception(seastar::format("{} nanoseconds doesn't fit in a 64 bit integer", v.str()));
|
|
}
|
|
} else {
|
|
throw exceptions::invalid_request_exception(format("invalid duration field: '{}'", k));
|
|
}
|
|
}
|
|
return cql_duration(months_counter(months), days_counter(days), nanoseconds_counter(nanoseconds));
|
|
}));
|
|
}
|
|
|
|
data_value operator()(const set_type_impl& t) {
|
|
std::vector<data_value> elements;
|
|
const data_type& element_type = t.get_elements_type();
|
|
lua_pushnil(l);
|
|
while (lua_next(l, -2) != 0) {
|
|
if (!lua_toboolean(l, -1)) {
|
|
throw exceptions::invalid_request_exception("sets are represented with tables with true values");
|
|
}
|
|
lua_pop(l, 1);
|
|
elements.push_back(convert_from_lua(l, element_type));
|
|
}
|
|
std::sort(elements.begin(), elements.end(), [&](const data_value& a, const data_value& b) {
|
|
// FIXME: this is madness, we have to be able to compare without serializing!
|
|
return element_type->less(a.serialize_nonnull(), b.serialize_nonnull());
|
|
});
|
|
return make_set_value(t.shared_from_this(), std::move(elements));
|
|
}
|
|
|
|
data_value operator()(const map_type_impl& t) {
|
|
const data_type& key_type = t.get_keys_type();
|
|
const data_type& value_type = t.get_values_type();
|
|
using map_pair = std::pair<data_value, data_value>;
|
|
std::vector<map_pair> elements;
|
|
lua_pushnil(l);
|
|
while (lua_next(l, -2) != 0) {
|
|
auto v = convert_from_lua(l, value_type);
|
|
lua_pop(l, 1);
|
|
auto k = convert_from_lua(l, key_type);
|
|
elements.push_back({k, v});
|
|
}
|
|
std::sort(elements.begin(), elements.end(), [&](const map_pair& a, const map_pair& b) {
|
|
// FIXME: this is madness, we have to be able to compare without serializing!
|
|
return key_type->less(a.first.serialize_nonnull(), b.first.serialize_nonnull());
|
|
});
|
|
return make_map_value(t.shared_from_this(), std::move(elements));
|
|
}
|
|
|
|
data_value operator()(const list_type_impl& t) {
|
|
if (!lua_istable(l, -1)) {
|
|
throw exceptions::invalid_request_exception("value is not a table");
|
|
}
|
|
|
|
const data_type& elements_type = t.get_elements_type();
|
|
using table_pair = std::pair<utils::multiprecision_int, data_value>;
|
|
std::vector<table_pair> pairs;
|
|
lua_pushnil(l);
|
|
while (lua_next(l, -2) != 0) {
|
|
auto v = convert_from_lua(l, elements_type);
|
|
lua_pop(l, 1);
|
|
pairs.push_back({get_varint(l, -1), v});
|
|
}
|
|
|
|
std::sort(pairs.begin(), pairs.end(), [] (const table_pair& a, const table_pair& b) {
|
|
return a.first < b.first;
|
|
});
|
|
|
|
size_t num_elements = pairs.size();
|
|
std::vector<data_value> elements;
|
|
for (size_t i = 0; i < num_elements; ++i) {
|
|
if (utils::multiprecision_int(i + 1) != pairs[i].first) {
|
|
throw exceptions::invalid_request_exception("table is not a sequence");
|
|
}
|
|
elements.push_back(pairs[i].second);
|
|
}
|
|
return make_list_value(t.shared_from_this(), std::move(elements));
|
|
}
|
|
|
|
data_value operator()(const tuple_type_impl& t) {
|
|
if (!lua_istable(l, -1)) {
|
|
throw exceptions::invalid_request_exception("value is not a table");
|
|
}
|
|
|
|
size_t num_elements = t.size();
|
|
std::vector<std::optional<data_value>> opt_elements(num_elements);
|
|
|
|
lua_pushnil(l);
|
|
while (lua_next(l, -2) != 0) {
|
|
auto k_varint = get_varint(l, -2);
|
|
if (k_varint > num_elements || k_varint < 1) {
|
|
throw exceptions::invalid_request_exception(
|
|
seastar::format("key {} is not valid for a sequence of size {}", k_varint.str(), num_elements));
|
|
}
|
|
size_t k = size_t(k_varint);
|
|
opt_elements[k - 1] = convert_from_lua(l, t.type(k - 1));
|
|
lua_pop(l, 1);
|
|
}
|
|
|
|
std::vector<data_value> elements;
|
|
elements.reserve(num_elements);
|
|
for (size_t i = 0; i < num_elements; ++i) {
|
|
if (!opt_elements[i]) {
|
|
throw exceptions::invalid_request_exception(
|
|
format("key {} missing in sequence of size {}", i + 1, num_elements));
|
|
}
|
|
elements.push_back(*opt_elements[i]);
|
|
}
|
|
return make_tuple_value(t.shared_from_this(), std::move(elements));
|
|
}
|
|
|
|
data_value operator()(const vector_type_impl& t) {
|
|
if (!lua_istable(l, -1)) {
|
|
throw exceptions::invalid_request_exception("value is not a table");
|
|
}
|
|
|
|
const data_type& elements_type = t.get_elements_type();
|
|
vector_dimension_t num_elements = t.get_dimension();
|
|
|
|
using table_pair = std::pair<utils::multiprecision_int, data_value>;
|
|
std::vector<table_pair> pairs;
|
|
lua_pushnil(l);
|
|
while (lua_next(l, -2) != 0) {
|
|
auto v = convert_from_lua(l, elements_type);
|
|
lua_pop(l, 1);
|
|
auto k = get_varint(l, -1);
|
|
if (k > num_elements || k < 1) {
|
|
throw exceptions::invalid_request_exception(
|
|
seastar::format("key {} is not valid for a sequence of size {}", k.str(), num_elements));
|
|
}
|
|
pairs.push_back({k, v});
|
|
}
|
|
|
|
std::sort(pairs.begin(), pairs.end(), [] (const table_pair& a, const table_pair& b) {
|
|
return a.first < b.first;
|
|
});
|
|
|
|
std::vector<data_value> elements;
|
|
for (size_t i = 0; i < num_elements; ++i) {
|
|
if (utils::multiprecision_int(i + 1) != pairs[i].first) {
|
|
throw exceptions::invalid_request_exception(
|
|
format("key {} missing in sequence of size {}", i + 1, num_elements));
|
|
}
|
|
elements.push_back(pairs[i].second);
|
|
}
|
|
return make_vector_value(t.shared_from_this(), std::move(elements));
|
|
}
|
|
|
|
data_value operator()(const user_type_impl& t) {
|
|
size_t num_fields = t.field_types().size();
|
|
|
|
std::unordered_map<sstring, std::pair<unsigned, data_type>> field_types;
|
|
field_types.reserve(num_fields);
|
|
for (unsigned i = 0; i < num_fields; ++i) {
|
|
field_types.insert({t.field_name_as_string(i), {i, t.field_type(i)}});
|
|
}
|
|
|
|
std::vector<std::optional<data_value>> opt_elements(num_fields);
|
|
lua_pushnil(l);
|
|
while (lua_next(l, -2) != 0) {
|
|
auto s = get_string(l, -2);
|
|
auto iter = field_types.find(s);
|
|
if (iter == field_types.end()) {
|
|
throw exceptions::invalid_request_exception(format("invalid UDT field '{}'", s));
|
|
}
|
|
|
|
const auto &p = iter->second;
|
|
auto v = convert_from_lua(l, p.second);
|
|
lua_pop(l, 1);
|
|
|
|
opt_elements[p.first] = std::move(v);
|
|
}
|
|
|
|
std::vector<data_value> elements;
|
|
elements.reserve(num_fields);
|
|
for (size_t i = 0; i < num_fields; ++i) {
|
|
if (!opt_elements[i]) {
|
|
throw exceptions::invalid_request_exception(
|
|
format("key {} missing in udt {}", t.field_name_as_string(i), t.get_name_as_string()));
|
|
}
|
|
elements.push_back(*opt_elements[i]);
|
|
}
|
|
|
|
return make_user_value(t.shared_from_this(), std::move(elements));
|
|
}
|
|
|
|
data_value operator()(const inet_addr_type_impl& t) {
|
|
return t.from_string_view(get_string(l, -1));
|
|
}
|
|
|
|
data_value operator()(const uuid_type_impl&) {
|
|
return uuid_type_impl::from_string_view(get_string(l, -1));
|
|
}
|
|
|
|
data_value operator()(const timeuuid_type_impl&) {
|
|
return timeuuid_native_type{timeuuid_type_impl::from_string_view(get_string(l, -1))};
|
|
}
|
|
|
|
data_value operator()(const bytes_type_impl& t) {
|
|
sstring v = get_string(l, -1);
|
|
return data_value(bytes(reinterpret_cast<const int8_t*>(v.data()), v.size()));
|
|
}
|
|
|
|
data_value operator()(const utf8_type_impl& t) {
|
|
sstring s = get_string(l, -1);
|
|
auto error_pos = utils::utf8::validate_with_error_position(reinterpret_cast<uint8_t*>(s.data()), s.size());
|
|
if (error_pos) {
|
|
throw exceptions::invalid_request_exception(format("value is not valid utf8, invalid character at byte offset {}", *error_pos));
|
|
}
|
|
return s;
|
|
}
|
|
|
|
data_value operator()(const ascii_type_impl& t) {
|
|
sstring s = get_string(l, -1);
|
|
if (utils::ascii::validate(reinterpret_cast<uint8_t*>(s.data()), s.size())) {
|
|
return ascii_native_type{std::move(s)};
|
|
}
|
|
throw exceptions::invalid_request_exception("value is not valid ascii");
|
|
}
|
|
|
|
data_value operator()(const boolean_type_impl& t) {
|
|
return bool(lua_toboolean(l, -1));
|
|
}
|
|
|
|
template <typename T> data_value operator()(const floating_type_impl<T>& t) {
|
|
return visit_lua_number(l, -1, make_visitor(
|
|
[] (const big_decimal& v) -> T { return decimal_to_double(v); },
|
|
[] (const auto& v) { return T(v); }
|
|
));
|
|
}
|
|
|
|
int64_t get_integer() {
|
|
return from_varint_to_integer(get_varint(l, -1));
|
|
}
|
|
|
|
data_value operator()(const timestamp_date_base_class& t) {
|
|
return visit_lua_value(l, -1, timestamp_return_visitor{l});
|
|
}
|
|
|
|
data_value operator()(const time_type_impl& t) {
|
|
return time_native_type{visit_lua_value(l, -1, make_visitor(
|
|
[] (const auto&) -> int64_t {
|
|
throw exceptions::invalid_request_exception("time must be a string or an integer");
|
|
},
|
|
[] (const utils::multiprecision_int& v) {
|
|
int64_t v2 = int64_t(v);
|
|
if (v2 == v) {
|
|
return v2;
|
|
}
|
|
throw exceptions::invalid_request_exception("time value must fit in signed 64 bits");
|
|
},
|
|
[] (const std::string_view& v) {
|
|
return time_type_impl::from_string_view(v);
|
|
}
|
|
))};
|
|
}
|
|
|
|
data_value operator()(const counter_type_impl&) {
|
|
// No data_value ever has a counter type, it is represented
|
|
// with long_type instead.
|
|
return get_integer();
|
|
}
|
|
|
|
template <typename T> data_value operator()(const integer_type_impl<T>& t) {
|
|
return T(get_integer());
|
|
}
|
|
|
|
data_value operator()(const simple_date_type_impl& t) {
|
|
return simple_date_native_type{visit_lua_value(l, -1, simple_date_return_visitor{l})};
|
|
}
|
|
};
|
|
|
|
uint32_t simple_date_return_visitor::operator()(const lua_table&) {
|
|
auto table = get_lua_date_table(l, -1);
|
|
if (table.hour || table.minute || table.second) {
|
|
throw exceptions::invalid_request_exception("date type has no hour, minute or second");
|
|
}
|
|
date::year_month_day ymd{date::year{table.year}, date::month{table.month}, date::day{table.day}};
|
|
int64_t days = date::local_days(ymd).time_since_epoch().count() + (1UL << 31);
|
|
return (*this)(utils::multiprecision_int(days));
|
|
}
|
|
|
|
db_clock::time_point timestamp_return_visitor::operator()(const lua_table&) {
|
|
auto table = get_lua_date_table(l, -1);
|
|
boost::gregorian::date date(table.year, table.month, table.day);
|
|
boost::posix_time::time_duration time(table.hour.value_or(12), table.minute.value_or(0), table.second.value_or(0));
|
|
boost::posix_time::ptime timestamp(date, time);
|
|
int64_t msec = (timestamp - boost::posix_time::from_time_t(0)).total_milliseconds();
|
|
return (*this)(utils::multiprecision_int(msec));
|
|
}
|
|
}
|
|
|
|
// Converts the value on top of the Lua stack to a data_value of `type`; a Lua
// nil becomes a null value of that type.
static data_value convert_from_lua(lua_State* l, const data_type& type) {
    return lua_isnil(l, -1)
        ? data_value::make_null(type)
        : ::visit(*type, from_lua_visitor{l});
}
|
|
|
|
// Serializes a UDF's result as `return_type`. Exactly one value must be on the
// Lua stack; otherwise invalid_request_exception is thrown.
static bytes_opt convert_return(lua_slice_state &l, const data_type& return_type) {
    const int returned = lua_gettop(l);
    if (returned == 1) {
        // FIXME: It should be possible to avoid creating the data_value,
        // or even better, change the function::execute interface to
        // return a data_value instead of bytes_opt.
        return convert_from_lua(l, return_type).serialize();
    }
    throw exceptions::invalid_request_exception(
            format("{} values returned, expected {}", returned, 1));
}
|
|
|
|
// Pushes `v` onto the Lua stack as a Lua string. The length is passed
// explicitly, so embedded NUL bytes are preserved.
void lua::push_sstring(lua_State* l, const sstring& v) {
    const char* first = v.c_str();
    const auto length = v.size();
    lua_pushlstring(l, first, length);
}
|
|
|
|
// Forward declaration: the collection overloads of to_lua_visitor below call
// push_argument recursively for every element, key and value they push.
static void push_argument(lua_State* l, const data_value& arg);
|
|
|
|
namespace {
|
|
struct to_lua_visitor {
|
|
lua_State* l;
|
|
|
|
void operator()(const varint_type_impl& t, const emptyable<utils::multiprecision_int>* v) {
|
|
push_cpp_int(l, *v);
|
|
}
|
|
|
|
void operator()(const decimal_type_impl& t, const emptyable<big_decimal>* v) {
|
|
push_big_decimal(l, *v);
|
|
}
|
|
|
|
void operator()(const counter_type_impl& t, const void* v) {
|
|
// This is unreachable since deserialize_visitor for
|
|
// counter_type_impl return a long.
|
|
abort();
|
|
}
|
|
|
|
void operator()(const empty_type_impl& t, const void* v) {
|
|
// This is unreachable since empty types are not user visible.
|
|
abort();
|
|
}
|
|
|
|
void operator()(const reversed_type_impl& t, const void* v) {
|
|
// This is unreachable since reversed_type_impl is used only
|
|
// in the tables. The function gets the underlying type.
|
|
abort();
|
|
}
|
|
|
|
void operator()(const map_type_impl& t, const std::vector<std::pair<data_value, data_value>>* v) {
|
|
// returns the table { k1 = v1, k2 = v2, ...}
|
|
lua_createtable(l, 0, v->size());
|
|
for (const auto& p : *v) {
|
|
push_argument(l, p.first);
|
|
push_argument(l, p.second);
|
|
lua_rawset(l, -3);
|
|
}
|
|
}
|
|
|
|
void operator()(const user_type_impl& t, const std::vector<data_value>* v) {
|
|
// returns the table { field1 = v1, field2 = v2, ...}
|
|
lua_createtable(l, 0, v->size());
|
|
for (int i = 0, n = v->size(); i < n; ++i) {
|
|
push_sstring(l, t.field_name_as_string(i));
|
|
push_argument(l, (*v)[i]);
|
|
lua_rawset(l, -3);
|
|
}
|
|
}
|
|
|
|
void operator()(const set_type_impl& t, const std::vector<data_value>* v) {
|
|
// returns the table { v1 = true, v2 = true, ...}
|
|
lua_createtable(l, 0, v->size());
|
|
for (const data_value& dv : *v) {
|
|
push_argument(l, dv);
|
|
lua_pushboolean(l, true);
|
|
lua_rawset(l, -3);
|
|
}
|
|
}
|
|
|
|
template <typename T>
|
|
void operator()(const concrete_type<std::vector<data_value>, T>& t, const std::vector<data_value>* v) {
|
|
// returns the table {v1, v2, ...}
|
|
lua_createtable(l, v->size(), 0);
|
|
int i = 0;
|
|
for (const data_value& dv : *v) {
|
|
push_argument(l, dv);
|
|
lua_rawseti(l, -2, ++i);
|
|
}
|
|
}
|
|
|
|
void operator()(const boolean_type_impl& t, const emptyable<bool>* v) {
|
|
lua_pushboolean(l, *v);
|
|
}
|
|
|
|
template <typename T>
|
|
void operator()(const floating_type_impl<T>& t, const emptyable<T>* v) {
|
|
// floats are converted to double
|
|
lua_pushnumber(l, *v);
|
|
}
|
|
|
|
template <typename T>
|
|
void operator()(const integer_type_impl<T>& t, const emptyable<T>* v) {
|
|
// Integers are converted to 64 bits
|
|
lua_pushinteger(l, *v);
|
|
}
|
|
|
|
void operator()(const bytes_type_impl& t, const bytes* v) {
|
|
// lua strings can hold arbitrary blobs
|
|
lua_pushlstring(l, reinterpret_cast<const char*>(v->c_str()), v->size());
|
|
}
|
|
|
|
void operator()(const string_type_impl& t, const sstring* v) {
|
|
push_sstring(l, *v);
|
|
}
|
|
|
|
void operator()(const time_type_impl& t, const emptyable<int64_t>* v) {
|
|
// nanoseconds since midnight
|
|
lua_pushinteger(l, *v);
|
|
}
|
|
|
|
void operator()(const timestamp_date_base_class& t, const timestamp_date_base_class::native_type* v) {
|
|
// milliseconds since epoch
|
|
lua_pushinteger(l, v->get().time_since_epoch().count());
|
|
}
|
|
|
|
void operator()(const simple_date_type_impl& t, const emptyable<uint32_t>* v) {
|
|
// number of days since epoch + 2^31
|
|
lua_pushinteger(l, *v);
|
|
}
|
|
|
|
void operator()(const duration_type_impl& t, const emptyable<cql_duration>* v) {
|
|
// returns the table { months = v1, days = v2, nanoseconds = v3 }
|
|
const cql_duration& d = v->get();
|
|
lua_createtable(l, 3, 0);
|
|
|
|
lua_pushinteger(l, d.months);
|
|
lua_setfield(l, -2, "months");
|
|
|
|
lua_pushinteger(l, d.days);
|
|
lua_setfield(l, -2, "days");
|
|
|
|
lua_pushinteger(l, d.nanoseconds);
|
|
lua_setfield(l, -2, "nanoseconds");
|
|
}
|
|
|
|
void operator()(const inet_addr_type_impl& t, const emptyable<seastar::net::inet_address>* v) {
|
|
// returns a string
|
|
sstring s = fmt::to_string(v->get());
|
|
push_sstring(l, s);
|
|
}
|
|
|
|
void operator()(const concrete_type<utils::UUID>&, const emptyable<utils::UUID>* v) {
|
|
// returns a string
|
|
push_sstring(l, fmt::to_string(v->get())) ;
|
|
}
|
|
};
|
|
}
|
|
|
|
// Pushes `arg` onto the Lua stack: nil for a null value, otherwise the Lua
// representation built by to_lua_visitor. In both cases exactly one value is
// pushed.
static void push_argument(lua_State* l, const data_value& arg) {
    if (!arg.is_null()) {
        ::visit(arg, to_lua_visitor{l});
    } else {
        lua_pushnil(l);
    }
}
|
|
|
|
// run the script for at most max_instructions
|
|
// Runs the compiled Lua function in `bitcode` with `values` as its arguments
// and returns its single result, serialized as `return_type`.
//
// Execution is cooperative: the debug hook installed before every resume is
// expected to yield, and each yield gives control back to the reactor through
// repeat_until_value. The total time spent inside lua_resume is compared
// against cfg.timeout_in_ms after every yield.
//
// Throws std::runtime_error if the Lua stack cannot hold the arguments.
// Throws exceptions::invalid_request_exception on timeout, on a Lua runtime
// error, or when the result cannot be converted (see convert_return).
future<bytes_opt> lua::run_script(lua::bitcode_view bitcode, const std::vector<data_value>& values, data_type return_type, const lua::runtime_config& cfg) {
    lua_slice_state l = load_script(cfg, bitcode);
    unsigned nargs = values.size();
    if (!lua_checkstack(l, nargs)) {
        throw std::runtime_error("could not push args to the stack");
    }
    for (const data_value& arg : values) {
        push_argument(l, arg);
    }

    // We don't update the timeout once we start executing the function
    using millisecond = std::chrono::duration<double, std::milli>;
    using duration = std::chrono::system_clock::duration;
    duration elapsed{0};
    duration timeout = std::chrono::duration_cast<duration>(millisecond(cfg.timeout_in_ms));
    return repeat_until_value([l = std::move(l), elapsed, return_type, nargs, timeout = std::move(timeout)] () mutable {
        // Set the hook before resuming. We have to do it here since the hook can reset itself
        // if it detects we are spending too much time in C.
        // The hook will be called after 1000 instructions.
        lua_sethook(l, debug_hook, LUA_MASKCALL | LUA_MASKCOUNT, 1000);
        auto start = ::now();
        LUA_504_PLUS(int nresults;)
        switch (lua_resume(l, nullptr, nargs LUA_504_PLUS(, &nresults))) {
        case LUA_OK:
            return make_ready_future<std::optional<bytes_opt>>(convert_return(l, return_type));
        case LUA_YIELD: {
            // The arguments were consumed by the first resume; later resumes
            // pass none.
            nargs = 0;
            elapsed += ::now() - start;
            if (elapsed > timeout) {
                millisecond ms = elapsed;
                throw exceptions::invalid_request_exception(format("lua execution timeout: {}ms elapsed", ms.count()));
            }
            return make_ready_future<std::optional<bytes_opt>>(std::nullopt);
        }
        default: {
            // Lua code may raise any value as an error (e.g. `error({})`).
            // lua_tostring returns NULL for anything that is not a string or a
            // number, and appending a null char* to std::string is undefined
            // behavior, so describe such error objects by their type instead.
            std::string reason = "lua execution failed: ";
            if (const char* msg = lua_tostring(l, -1)) {
                reason += msg;
            } else {
                reason += "(error object is a ";
                reason += luaL_typename(l, -1);
                reason += " value)";
            }
            throw exceptions::invalid_request_exception(reason);
        }
        }
    });
}
|
|
|
|
namespace lua {
|
|
|
|
void register_metatables(lua_State* l) {
|
|
luaL_newmetatable(l, scylla_decimal_metatable_name);
|
|
lua_pushvalue(l, -1);
|
|
lua_setfield(l, -2, "__index");
|
|
luaL_setfuncs(l, decimal_methods, 0);
|
|
lua_pop(l, 1);
|
|
}
|
|
|
|
void push_data_value(lua_State* l, const data_value& value) {
|
|
push_argument(l, value);
|
|
}
|
|
|
|
data_value pop_data_value(lua_State* l, const data_type& type) {
|
|
return convert_from_lua(l, type);
|
|
}
|
|
|
|
}
|