diff --git a/serializer_impl.hh b/serializer_impl.hh index 9aa0356489..2a9dab14b3 100644 --- a/serializer_impl.hh +++ b/serializer_impl.hh @@ -791,15 +791,50 @@ unknown_variant_type deserialize(Input& in, boost::type) { // using a range. // Use begin() and end() to iterate through the frozen vector, // deserializing (or skipping) one element at a time. -template +template class vector_deserializer { public: using value_type = T; - using input_stream = InputStream; + using input_stream = utils::input_stream; private: input_stream _in; size_t _size; + utils::chunked_vector _substreams; + + void fill_substreams() requires (!IsForward) { + input_stream in = _in; + input_stream in2 = _in; + for (size_t i = 0; i < size(); ++i) { + size_t old_size = in.size(); + serializer::skip(in); + size_t new_size = in.size(); + + _substreams.push_back(in2.read_substream(old_size - new_size)); + } + } + + struct forward_iterator_data { + input_stream _in = simple_input_stream(); + void skip() { + serializer::skip(_in); + } + value_type deserialize_next() { + return deserialize(_in, boost::type()); + } + }; + struct reverse_iterator_data { + std::reverse_iterator::const_iterator> _substream_it; + void skip() { + ++_substream_it; + } + value_type deserialize_next() { + input_stream is = *_substream_it; + ++_substream_it; + return deserialize(is, boost::type()); + } + }; + public: vector_deserializer() noexcept @@ -810,7 +845,11 @@ public: explicit vector_deserializer(input_stream in) : _in(std::move(in)) , _size(deserialize(_in, boost::type())) - { } + { + if constexpr (!IsForward) { + fill_substreams(); + } + } // Get the number of items in the vector size_t size() const noexcept { @@ -823,13 +862,18 @@ public: // Input iterator class iterator { - input_stream _in; + // _idx is the distance from .begin(). It is used only for comparing iterators. size_t _idx = 0; bool _consumed = false; + std::conditional_t _data; - iterator(input_stream in, size_t idx) noexcept - : _in(in) - , _idx(idx) + iterator(input_stream in, size_t idx) noexcept requires(IsForward) + : _idx(idx) + , _data{in} + { } + iterator(decltype(reverse_iterator_data::_substream_it) substreams, size_t idx) noexcept requires(!IsForward) + : _idx(idx) + , _data{substreams} { } friend class vector_deserializer; @@ -840,7 +884,7 @@ public: using reference = value_type&; using difference_type = ssize_t; - iterator() noexcept : _in(simple_input_stream()) {} + iterator() noexcept = default; bool operator==(const iterator& it) const noexcept { return _idx == it._idx; @@ -849,17 +893,14 @@ public: // Deserializes and returns the item, effectively incrementing the iterator.. value_type operator*() const { auto zis = const_cast(this); - auto item = deserialize(zis->_in, boost::type()); zis->_idx++; zis->_consumed = true; - return item; + return zis->_data.deserialize_next(); } iterator& operator++() { if (!_consumed) { - serializer::skip(_in); - // auto len = read_frame_size(); - // _in.skip(len); + _data.skip(); ++_idx; } else { _consumed = false; @@ -882,25 +923,46 @@ public: static_assert(std::input_iterator); static_assert(std::sentinel_for); - iterator begin() noexcept { - return iterator(_in, 0); + iterator begin() noexcept requires(IsForward) { + return {_in, 0}; } - const_iterator begin() const noexcept { - return const_iterator(_in, 0); + const_iterator begin() const noexcept requires(IsForward) { + return {_in, 0}; } - const_iterator cbegin() const noexcept { - return const_iterator(_in, 0); + const_iterator cbegin() const noexcept requires(IsForward) { + return {_in, 0}; } - iterator end() noexcept { - return iterator(_in, _size); + iterator end() noexcept requires(IsForward) { + return {_in, _size}; } - const_iterator end() const noexcept { - return const_iterator(_in, _size); + const_iterator end() const noexcept requires(IsForward) { + return {_in, _size}; } - const_iterator cend() const noexcept { - return const_iterator(_in, _size); + const_iterator cend() const noexcept requires(IsForward) { + return {_in, _size}; } + + iterator begin() noexcept requires(!IsForward) { + return {_substreams.crbegin(), 0}; + } + const_iterator begin() const noexcept requires(!IsForward) { + return {_substreams.crbegin(), 0}; + } + const_iterator cbegin() const noexcept requires(!IsForward) { + return {_substreams.crbegin(), 0}; + } + + iterator end() noexcept requires(!IsForward) { + return {_substreams.crend(), _size}; + } + const_iterator end() const noexcept requires(!IsForward) { + return {_substreams.crend(), _size}; + } + const_iterator cend() const noexcept requires(!IsForward) { + return {_substreams.crend(), _size}; + } + }; static_assert(std::ranges::range>); diff --git a/test/boost/serialization_test.cc b/test/boost/serialization_test.cc index d40a7956c4..a89fbf9396 100644 --- a/test/boost/serialization_test.cc +++ b/test/boost/serialization_test.cc @@ -255,13 +255,71 @@ static void test_vector_deserializer(const std::vector& v) { } } +template +static void test_reverse_vector_deserializer(const std::vector& v) { + auto buf = ser::serialize_to_buffer(v); + auto in = simple_input_stream((const char*)buf.data(), buf.size()); + auto range = ser::vector_deserializer(in); + + auto test_equal = [] (const T& lhs, const T& rhs) { + if (lhs != rhs) { + throw std::runtime_error("compared values differ"); + } + }; + + auto required = [] (bool x) { + if (!x) { + throw std::runtime_error(format("failed requirment")); + } + }; + + { + auto vit = v.rbegin(); + auto rit = range.begin(); + while (rit != range.end()) { + test_equal(*rit, *vit); + ++rit; + ++vit; + } + required(vit == v.rend()); + } + + { + auto vit = v.rbegin(); + auto rit = range.begin(); + while (rit != range.end()) { + test_equal(*rit++, *vit++); + } + required(vit == v.rend()); + } + + { + auto cvit = v.crbegin(); + auto crit = range.cbegin(); + while (crit != range.cend()) { + test_equal(*crit++, *cvit++); + } + required(cvit == v.crend()); + } + + { + auto vit = v.rbegin(); + for (auto i : range) { + test_equal(i, *vit++); + } + } +} + BOOST_AUTO_TEST_CASE(vector_deserializer) { std::vector int_vect = { 3, 1, 4 }; test_vector_deserializer(int_vect); + test_reverse_vector_deserializer(int_vect); std::vector sstring_vect = { "testing", "one", "two", "three" }; test_vector_deserializer(sstring_vect); + test_reverse_vector_deserializer(sstring_vect); std::vector> opt_bool_vect = { true, false, {}, false, true }; test_vector_deserializer(opt_bool_vect); + test_reverse_vector_deserializer(opt_bool_vect); }