/*
* Copyright (C) 2015 ScyllaDB
*/
/*
* This file is part of Scylla.
*
* Scylla is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Scylla is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Scylla. If not, see <http://www.gnu.org/licenses/>.
*/
#pragma once
#include <boost/iterator/transform_iterator.hpp>
#include <seastar/core/bitset-iter.hh>

#include <algorithm>
#include <bitset>
#include <cstddef>
#include <limits>
#include <stdexcept>
#include <type_traits>
#include <utility>
/**
*
* Makes it possible to take full advantage of compile-time information when
* operating on a set of enum values.
*
* Examples:
*
* enum class x { A, B, C };
* using my_enum = super_enum<x, x::A, x::B, x::C>;
* using my_enumset = enum_set<my_enum>;
*
* static_assert(my_enumset::frozen<x::A, x::B>::contains<x::A>(), "it should...");
*
* assert(my_enumset::frozen<x::A, x::B>::contains(my_enumset::prepare<x::A>()));
*
* assert(my_enumset::frozen<x::A, x::B>::contains(x::A));
*
*/
template
struct super_enum {
using enum_type = EnumType;
template
struct max {
static constexpr enum_type max_of(enum_type a, enum_type b) {
return a > b ? a : b;
}
template
static constexpr enum_type get() {
return max_of(first, get());
}
template
static constexpr enum_type get() { return first; }
static constexpr enum_type value = get();
};
template
struct min {
static constexpr enum_type min_of(enum_type a, enum_type b) {
return a < b ? a : b;
}
template
static constexpr enum_type get() {
return min_of(first, get());
}
template
static constexpr enum_type get() { return first; }
static constexpr enum_type value = get();
};
using sequence_type = typename std::underlying_type::type;
template
struct valid_sequence {
static constexpr bool apply(sequence_type v) noexcept {
return (v == static_cast(first)) || valid_sequence::apply(v);
}
};
template
struct valid_sequence {
static constexpr bool apply(sequence_type v) noexcept {
return v == static_cast(first);
}
};
static constexpr bool is_valid_sequence(sequence_type v) noexcept {
return valid_sequence::apply(v);
}
template
static constexpr sequence_type sequence_for() {
return static_cast(Elem);
}
static sequence_type sequence_for(enum_type elem) {
return static_cast(elem);
}
static constexpr sequence_type max_sequence = sequence_for::value>();
static constexpr sequence_type min_sequence = sequence_for::value>();
static_assert(min_sequence >= 0, "negative enum values unsupported");
};
// Raised by enum_set::from_mask() when the supplied bit mask has a bit set at
// a position that does not correspond to any valid enumerator.
class bad_enum_set_mask : public std::invalid_argument {
public:
    bad_enum_set_mask()
        : invalid_argument{"Bit mask contains invalid enumeration indices."}
    {}
};
template
class enum_set {
public:
using mask_type = size_t; // TODO: use the smallest sufficient type
using enum_type = typename Enum::enum_type;
private:
static constexpr int mask_digits = std::numeric_limits::digits;
using mask_iterator = seastar::bitsets::set_iterator;
mask_type _mask;
constexpr enum_set(mask_type mask) : _mask(mask) {}
template
static constexpr unsigned shift_for() {
return Enum::template sequence_for();
}
static auto make_iterator(mask_iterator iter) {
return boost::make_transform_iterator(std::move(iter), [](typename Enum::sequence_type s) {
return enum_type(s);
});
}
public:
using iterator = std::invoke_result_t;
constexpr enum_set() : _mask(0) {}
/**
* \throws \ref bad_enum_set_mask
*/
static constexpr enum_set from_mask(mask_type mask) {
const auto bit_range = seastar::bitsets::for_each_set(std::bitset(mask));
if (!std::all_of(bit_range.begin(), bit_range.end(), &Enum::is_valid_sequence)) {
throw bad_enum_set_mask();
}
return enum_set(mask);
}
static constexpr mask_type full_mask() {
return ~(std::numeric_limits::max() << (Enum::max_sequence + 1));
}
static constexpr enum_set full() {
return enum_set(full_mask());
}
static inline mask_type mask_for(enum_type e) {
return mask_type(1) << Enum::sequence_for(e);
}
template
static constexpr mask_type mask_for() {
return mask_type(1) << shift_for();
}
struct prepared {
mask_type mask;
bool operator==(const prepared& o) const {
return mask == o.mask;
}
};
static prepared prepare(enum_type e) {
return {mask_for(e)};
}
template
static constexpr prepared prepare() {
return {mask_for()};
}
static_assert(std::numeric_limits::max() >= ((size_t)1 << Enum::max_sequence), "mask type too small");
template
bool contains() const {
return bool(_mask & mask_for());
}
bool contains(enum_type e) const {
return bool(_mask & mask_for(e));
}
template
void remove() {
_mask &= ~mask_for();
}
void remove(enum_type e) {
_mask &= ~mask_for(e);
}
template
void set() {
_mask |= mask_for();
}
template
void set_if(bool condition) {
_mask |= mask_type(condition) << shift_for();
}
void set(enum_type e) {
_mask |= mask_for(e);
}
explicit operator bool() const {
return bool(_mask);
}
mask_type mask() const {
return _mask;
}
iterator begin() const {
return make_iterator(mask_iterator(_mask));
}
iterator end() const {
return make_iterator(mask_iterator(0));
}
template
struct frozen {
template
static constexpr mask_type make_mask() {
return mask_for();
}
static constexpr mask_type make_mask() {
return 0;
}
template
static constexpr mask_type make_mask() {
return mask_for() | make_mask();
}
static constexpr mask_type mask = make_mask();
template
static constexpr bool contains() {
return mask & mask_for();
}
static bool contains(enum_type e) {
return mask & mask_for(e);
}
static bool contains(prepared e) {
return mask & e.mask;
}
static enum_set unfreeze() {
return enum_set(mask);
}
};
template
static enum_set of() {
return frozen::unfreeze();
}
};