diff --git a/locator/load_sketch.hh b/locator/load_sketch.hh index 5598fba707..b298311a02 100644 --- a/locator/load_sketch.hh +++ b/locator/load_sketch.hh @@ -17,6 +17,7 @@ #include #include +#include #include #include @@ -33,21 +34,39 @@ class load_sketch { shard_id id; load_type load; }; - // Used in a max-heap to yield lower load first. + + // Less-comparator which orders by load first (ascending), and then by shard id (ascending). struct shard_load_cmp { bool operator()(const shard_load& a, const shard_load& b) const { - return a.load > b.load; + return a.load == b.load ? a.id < b.id : a.load < b.load; } }; + struct node_load { - std::vector _shards; + absl::btree_set _shards_by_load; + std::vector _shards; load_type _load = 0; node_load(size_t shard_count) : _shards(shard_count) { - shard_id next_shard = 0; - for (auto&& s : _shards) { - s.id = next_shard++; - s.load = 0; + for (shard_id i = 0; i < shard_count; ++i) { + _shards[i] = 0; + } + } + + void update_shard_load(shard_id shard, load_type load_delta) { + _load += load_delta; + + auto old_load = _shards[shard]; + auto new_load = old_load + load_delta; + _shards_by_load.erase(shard_load{shard, old_load}); + _shards[shard] = new_load; + _shards_by_load.insert(shard_load{shard, new_load}); + } + + void populate_shards_by_load() { + _shards_by_load.clear(); + for (shard_id i = 0; i < _shards.size(); ++i) { + _shards_by_load.insert(shard_load{i, _shards[i]}); } } @@ -81,7 +100,8 @@ private: node_load& n = _nodes.at(replica.host); if (replica.shard < n._shards.size()) { n.load() += 1; - n._shards[replica.shard].load += 1; + n._shards[replica.shard] += 1; + // Note: as an optimization, _shards_by_load is populated later in populate_shards_by_load() } } return make_ready_future<>(); @@ -104,8 +124,8 @@ public: } } - for (auto&& n : _nodes) { - std::make_heap(n.second._shards.begin(), n.second._shards.end(), shard_load_cmp()); + for (auto&& [id, n] : _nodes) { + n.populate_shards_by_load(); } } @@ -116,39 +136,24 @@ public: if (shard_count == 0) { throw std::runtime_error(format("Shard count not known for node {}", node)); } - _nodes.emplace(node, node_load{shard_count}); + auto [i, _] = _nodes.emplace(node, node_load{shard_count}); + i->second.populate_shards_by_load(); } auto& n = _nodes.at(node); - std::pop_heap(n._shards.begin(), n._shards.end(), shard_load_cmp()); - shard_load& s = n._shards.back(); + const shard_load& s = *n._shards_by_load.begin(); auto shard = s.id; - s.load += 1; - n.load() += 1; - std::push_heap(n._shards.begin(), n._shards.end(), shard_load_cmp()); + n.update_shard_load(shard, 1); return shard; } void unload(host_id node, shard_id shard) { auto& n = _nodes.at(node); - for (auto& shard_load : n._shards) { - if (shard_load.id == shard) { - assert(shard_load.load > 0); - --shard_load.load; - break; - } - } - std::make_heap(n._shards.begin(), n._shards.end(), shard_load_cmp()); + n.update_shard_load(shard, -1); } void pick(host_id node, shard_id shard) { auto& n = _nodes.at(node); - for (auto& shard_load : n._shards) { - if (shard_load.id == shard) { - ++shard_load.load; - break; - } - } - std::make_heap(n._shards.begin(), n._shards.end(), shard_load_cmp()); + n.update_shard_load(shard, 1); } load_type get_load(host_id node) const { @@ -201,8 +206,8 @@ public: min_max_tracker minmax; if (_nodes.contains(node)) { auto& n = _nodes.at(node); - for (auto&& s: n._shards) { - minmax.update(s.load); + for (auto&& load: n._shards) { + minmax.update(load); } } else { minmax.update(0);