diff --git a/utils/stream_compressor.cc b/utils/stream_compressor.cc
index 5c19a9edd4..a5ed557964 100644
--- a/utils/stream_compressor.cc
+++ b/utils/stream_compressor.cc
@@ -68,12 +68,48 @@ static void check_zstd(size_t ret, const char* text) {
     }
 }
 
+// IMPORTANT: with this constant, we are hardcoding a max window size in the decoder.
+//
+// We do this to make the allocator usage by ZSTD more predictable.
+// (If you let zstd alloc and free memory on its own,
+// it has some "helpful" heuristics which sometimes reallocate the buffers
+// even if the context is reused, and we want to avoid that.
+// See https://github.com/facebook/zstd/commit/3d523c741be041f17c28e43b89ab6dfcaee281d2)
+//
+// But the hardcoding means that the decoder won't be able to handle frames with a window
+// size bigger than this. (The decoding will return an error that there's not enough memory).
+//
+// That means that the window size becomes a part of the protocol:
+// the compressor mustn't send messages with a greater window size.
+//
+// If you wish to enlarge the window size for whatever reason,
+// you have to ensure that the new version of the compressor isn't talking
+// to a decompressor with a window size hardcoded to something too small.
+//
+// (In case of RPC compression, you can do this by bumping the COMPRESSOR_NAME
+// in advanced_rpc_compressor.cc when releasing the new Scylla version).
+constexpr size_t ZSTD_WINDOW_SIZE_LOG = 17;
+
 zstd_dstream::zstd_dstream() {
-    _ctx.reset(ZSTD_createDStream());
-    check_zstd(ZSTD_DCtx_setParameter(_ctx.get(), ZSTD_d_format, ZSTD_f_zstd1_magicless), "ZSTD_CCtx_setParameter(.., ZSTD_c_format, ZSTD_f_zstd1_magicless)");
+    // IMPORTANT: the window size set in the decompressor must not be smaller
+    // than the window size used by the compressor
+    // talking to us.
+    const size_t window_size = size_t{1} << ZSTD_WINDOW_SIZE_LOG;
+    const size_t workspace_size = ZSTD_estimateDStreamSize(window_size);
+    {
+        // zstd needs a large contiguous allocation, it's unavoidable.
+        const memory::scoped_large_allocation_warning_threshold slawt{1024*1024+1};
+        _ctx.reset(static_cast<ZSTD_DStream*>(malloc(workspace_size)));
+    }
     if (!_ctx) {
         throw std::bad_alloc();
     }
+    if (!ZSTD_initStaticDStream(_ctx.get(), workspace_size)) {
+        throw std::runtime_error("ZSTD_initStaticDStream() failed");
+    }
+
+    // IMPORTANT: this must match the compressor.
+    check_zstd(ZSTD_DCtx_setParameter(_ctx.get(), ZSTD_d_format, ZSTD_f_zstd1_magicless), "ZSTD_DCtx_setParameter(.., ZSTD_d_format, ZSTD_f_zstd1_magicless)");
 }
 
 void zstd_dstream::reset() noexcept {
@@ -110,17 +146,41 @@ void zstd_dstream::set_dict(const ZSTD_DDict* dict) {
 }
 
 zstd_cstream::zstd_cstream() {
-    _ctx.reset(ZSTD_createCStream());
+    struct params_deleter {
+        void operator()(ZSTD_CCtx_params* params) const noexcept {
+            ZSTD_freeCCtxParams(params);
+        }
+    };
+
+    std::unique_ptr<ZSTD_CCtx_params, params_deleter> params(ZSTD_createCCtxParams());
+    // For now, we hardcode a 128 kiB window and the lowest compression level here.
+    // We don't need more for RPC compression (or rather: we value lower CPU
+    // usage over mildly stronger compression).
+    auto compression_level = 1;
+    check_zstd(ZSTD_CCtxParams_init(params.get(), compression_level), "ZSTD_CCtxParams_init(.., 1)");
+    // IMPORTANT: this must match the decompressor.
+    check_zstd(ZSTD_CCtxParams_setParameter(params.get(), ZSTD_c_format, ZSTD_f_zstd1_magicless), "ZSTD_CCtxParams_setParameter(.., ZSTD_c_format, ZSTD_f_zstd1_magicless)");
+    check_zstd(ZSTD_CCtxParams_setParameter(params.get(), ZSTD_c_contentSizeFlag, 0), "ZSTD_CCtxParams_setParameter(.., ZSTD_c_contentSizeFlag, 0)");
+    check_zstd(ZSTD_CCtxParams_setParameter(params.get(), ZSTD_c_checksumFlag, 0), "ZSTD_CCtxParams_setParameter(.., ZSTD_c_checksumFlag, 0)");
+    check_zstd(ZSTD_CCtxParams_setParameter(params.get(), ZSTD_c_dictIDFlag, 0), "ZSTD_CCtxParams_setParameter(.., ZSTD_c_dictIDFlag, 0)");
+    // IMPORTANT: the window size in the compressor mustn't be greater than
+    // the max window size that the decompressor can handle.
+    check_zstd(ZSTD_CCtxParams_setParameter(params.get(), ZSTD_c_windowLog, ZSTD_WINDOW_SIZE_LOG), "ZSTD_CCtxParams_setParameter(.., ZSTD_c_windowLog, 17)");
+
+    const size_t workspace_size = ZSTD_estimateCStreamSize_usingCCtxParams(params.get());
+    {
+        // zstd needs a large contiguous allocation, it's unavoidable.
+        const memory::scoped_large_allocation_warning_threshold slawt{1024*1024+1};
+        _ctx.reset(static_cast<ZSTD_CStream*>(malloc(workspace_size)));
+    }
     if (!_ctx) {
         throw std::bad_alloc();
     }
-    // For now, we hardcode a 128 kiB window and the lowest compression level here.
-    check_zstd(ZSTD_initCStream(_ctx.get(), 1), "ZSTD_initCStream(.., 1)");
-    check_zstd(ZSTD_CCtx_setParameter(_ctx.get(), ZSTD_c_format, ZSTD_f_zstd1_magicless), "ZSTD_CCtx_setParameter(.., ZSTD_c_format, ZSTD_f_zstd1_magicless)");
-    check_zstd(ZSTD_CCtx_setParameter(_ctx.get(), ZSTD_c_contentSizeFlag, 0), "ZSTD_CCtx_setParameter(.., ZSTD_c_contentSizeFlag, 0)");
-    check_zstd(ZSTD_CCtx_setParameter(_ctx.get(), ZSTD_c_checksumFlag, 0), "ZSTD_CCtx_setParameter(.., ZSTD_c_checksumFlag, 0)");
-    check_zstd(ZSTD_CCtx_setParameter(_ctx.get(), ZSTD_c_dictIDFlag, 0), "ZSTD_CCtx_setParameter(.., ZSTD_c_dictIDFlag, 0)");
-    check_zstd(ZSTD_CCtx_setParameter(_ctx.get(), ZSTD_c_windowLog, 17), "ZSTD_CCtx_setParameter(.., ZSTD_c_windowLog, 17)");
+    if (!ZSTD_initStaticCStream(_ctx.get(), workspace_size)) {
+        throw std::runtime_error("ZSTD_initStaticCStream() failed");
+    }
+
+    check_zstd(ZSTD_CCtx_setParametersUsingCCtxParams(_ctx.get(), params.get()), "ZSTD_CCtx_setParametersUsingCCtxParams(..)");
 }
 
 size_t zstd_cstream::compress(ZSTD_outBuffer* out, ZSTD_inBuffer* in, ZSTD_EndDirective end) {
diff --git a/utils/stream_compressor.hh b/utils/stream_compressor.hh
index 4516a6e59b..b15a5d7d83 100644
--- a/utils/stream_compressor.hh
+++ b/utils/stream_compressor.hh
@@ -57,7 +57,8 @@ struct raw_stream final : public stream_compressor, public stream_decompressor {
 class zstd_dstream final : public stream_decompressor {
     struct ctx_deleter {
         void operator()(ZSTD_DStream* stream) const noexcept {
-            ZSTD_freeDStream(stream);
+            // The context is initialized in place inside a malloc()ed workspace, so it is released with free().
+            free(stream);
         }
     };
     std::unique_ptr<ZSTD_DStream, ctx_deleter> _ctx;
@@ -76,7 +77,8 @@ public:
 class zstd_cstream final : public stream_compressor {
     struct ctx_deleter {
         void operator()(ZSTD_CStream* stream) const noexcept {
-            ZSTD_freeCStream(stream);
+            // The context is initialized in place inside a malloc()ed workspace, so it is released with free().
+            free(stream);
         }
     };
     std::unique_ptr<ZSTD_CStream, ctx_deleter> _ctx;