/* * Copyright (C) 2015-present ScyllaDB */ /* * SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 */ #pragma once #include "utils/assert.hh" #include #include #include #include #include #include #include #include /** * * Allows to take full advantage of compile-time information when operating * on a set of enum values. * * Examples: * * enum class x { A, B, C }; * using my_enum = super_enum; * using my_enumset = enum_set; * * static_assert(my_enumset::frozen::contains(), "it should..."); * * SCYLLA_ASSERT(my_enumset::frozen::contains(my_enumset::prepare())); * * SCYLLA_ASSERT(my_enumset::frozen::contains(x::A)); * */ template struct super_enum { using enum_type = EnumType; template struct max { static constexpr enum_type max_of(enum_type a, enum_type b) { return a > b ? a : b; } template static constexpr enum_type get() { return max_of(first, get()); } template static constexpr enum_type get() { return first; } static constexpr enum_type value = get(); }; template struct min { static constexpr enum_type min_of(enum_type a, enum_type b) { return a < b ? a : b; } template static constexpr enum_type get() { return min_of(first, get()); } template static constexpr enum_type get() { return first; } static constexpr enum_type value = get(); }; using sequence_type = typename std::underlying_type::type; template struct valid_sequence { static constexpr bool apply(sequence_type v) noexcept { return (v == static_cast(first)) || valid_sequence::apply(v); } }; template struct valid_sequence { static constexpr bool apply(sequence_type v) noexcept { return v == static_cast(first); } }; static constexpr bool is_valid_sequence(sequence_type v) noexcept { return valid_sequence::apply(v); } template static constexpr sequence_type sequence_for() { return static_cast(Elem); } static sequence_type sequence_for(enum_type elem) { return static_cast(elem); } static constexpr sequence_type max_sequence = sequence_for::value>(); static constexpr sequence_type min_sequence = sequence_for::value>(); static_assert(min_sequence >= 0, "negative enum values unsupported"); }; class bad_enum_set_mask : public std::invalid_argument { public: bad_enum_set_mask() : std::invalid_argument("Bit mask contains invalid enumeration indices.") { } }; template class enum_set { public: using mask_type = size_t; // TODO: use the smallest sufficient type using enum_type = typename Enum::enum_type; private: static constexpr int mask_digits = std::numeric_limits::digits; using mask_iterator = seastar::bitsets::set_iterator; mask_type _mask; constexpr enum_set(mask_type mask) : _mask(mask) {} template static constexpr unsigned shift_for() { return Enum::template sequence_for(); } static auto make_iterator(mask_iterator iter) { return boost::make_transform_iterator(std::move(iter), [](typename Enum::sequence_type s) { return enum_type(s); }); } public: using iterator = std::invoke_result_t; constexpr enum_set() : _mask(0) {} /** * \throws \ref bad_enum_set_mask */ static constexpr enum_set from_mask(mask_type mask) { const auto bit_range = seastar::bitsets::for_each_set(std::bitset(mask)); if (!std::all_of(bit_range.begin(), bit_range.end(), &Enum::is_valid_sequence)) { throw bad_enum_set_mask(); } return enum_set(mask); } static constexpr mask_type full_mask() { return ~(std::numeric_limits::max() << (Enum::max_sequence + 1)); } static constexpr enum_set full() { return enum_set(full_mask()); } static inline mask_type mask_for(enum_type e) { return mask_type(1) << Enum::sequence_for(e); } template static constexpr mask_type mask_for() { return mask_type(1) << shift_for(); } struct prepared { mask_type mask; bool operator==(const prepared& o) const { return mask == o.mask; } }; static prepared prepare(enum_type e) { return {mask_for(e)}; } template static constexpr prepared prepare() { return {mask_for()}; } static_assert(std::numeric_limits::max() >= ((size_t)1 << Enum::max_sequence), "mask type too small"); template bool contains() const { return bool(_mask & mask_for()); } bool contains(enum_type e) const { return bool(_mask & mask_for(e)); } bool intersects(const enum_set& other) const { return bool(_mask & other._mask); } template void remove() { _mask &= ~mask_for(); } void remove(enum_type e) { _mask &= ~mask_for(e); } template void set() { _mask |= mask_for(); } template void set_if(bool condition) { _mask |= mask_type(condition) << shift_for(); } void set(enum_type e) { _mask |= mask_for(e); } template void toggle() { _mask ^= mask_for(); } void toggle(enum_type e) { _mask ^= mask_for(e); } void add(const enum_set& other) { _mask |= other._mask; } explicit operator bool() const { return bool(_mask); } mask_type mask() const { return _mask; } iterator begin() const { return make_iterator(mask_iterator(_mask)); } iterator end() const { return make_iterator(mask_iterator(0)); } template struct frozen { template static constexpr mask_type make_mask() { return mask_for(); } static constexpr mask_type make_mask() { return 0; } template static constexpr mask_type make_mask() { return mask_for() | make_mask(); } static constexpr mask_type mask = make_mask(); template static constexpr bool contains() { return mask & mask_for(); } static bool contains(enum_type e) { return mask & mask_for(e); } static bool contains(prepared e) { return mask & e.mask; } static constexpr enum_set unfreeze() { return enum_set(mask); } }; template static constexpr enum_set of() { return frozen::unfreeze(); } };