Files
scylladb/utils/base64.cc
Nadav Har'El 6d1abc5b2c utils/base64: fix misleading code and comment (no functional change)
utils/base64.cc had some strange code with a strange comment in
base64_begins_with().

The code had

        base.substr(operand.size() - 4, operand.size())

The comment claims that this is "last 4 bytes of base64-encoded string",
but this comment is misleading - operand is typically shorter than base
(this is the whole point of base64_begins_with()), so the real
intention of the code is not to find the *last* 4 bytes of base, but rather
the *next* four bytes after the (operand.size() - 4) which we already copied.
These are the four bytes that may need the full power of base64_decode_string()
because they may or may not contain padding.

But, if we really want the next 4 bytes, why pass operand.size() as the
length of the substring? operand.size() is at least 4 (it's a multiple of
4, and if it's 0 we returned earlier), but it could be more. We don't
need more, we just need 4. It's not really wrong to take more than 4 (so
this patch doesn't *fix* any bug), but can be wasteful. So this code
should be:

        base.substr(operand.size() - 4, 4)

We already have in test/boost/alternator_unit_test.cc a test,
test_base64_begins_with that takes encoded base64 strings up to 12
characters in length (corresponding to decoded strings up to 8 chars),
and substrings from length 0 to the base string's length, and check
that test_base64_begins_with succeeds.

Signed-off-by: Nadav Har'El <nyh@scylladb.com>

Closes scylladb/scylladb#25712
2025-09-01 08:57:50 +03:00

168 lines
5.4 KiB
C++

/*
* Copyright 2019-present ScyllaDB
*/
/*
* SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
*/
#include "base64.hh"
#include <seastar/core/format.hh>
// Lookup tables for the base64 alphabet, in both directions:
//  - to[v]   is the character that encodes the 6-bit value v (0..63).
//  - from[c] is the 6-bit value encoded by the byte c, or invalid_char
//            when c is not part of the alphabet (this includes '=').
// A single file-local instance, also named base64_chars, is created so
// that the reverse table is filled in once during static initialization.
static class base64_chars {
public:
    static constexpr const char to[] =
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    static constexpr uint8_t invalid_char = 255;
    uint8_t from[256];

    base64_chars() {
        // 64 alphabet characters plus the terminating NUL of the literal.
        static_assert(sizeof(to) == 64 + 1);
        // Start with every byte marked as "not a base64 character"...
        for (uint8_t& slot : from) {
            slot = invalid_char;
        }
        // ...then record the value of each of the 64 alphabet characters.
        for (int value = 0; value < 64; ++value) {
            from[static_cast<unsigned>(to[value])] = static_cast<uint8_t>(value);
        }
    }
} base64_chars;
// Encodes the bytes of `in` as standard base64 text (alphabet taken from
// base64_chars.to), padding the output with '=' so that its length is
// always a multiple of 4.
std::string base64_encode(bytes_view in) {
    std::string encoded;
    // Reserve the output size rounded up to a multiple of 4.
    encoded.reserve(((4 * in.size() / 3) + 3) & ~3);

    unsigned char group[3]; // up to three input bytes awaiting encoding
    int pending = 0;        // how many entries of group[] are filled
    for (auto byte : in) {
        group[pending] = byte;
        ++pending;
        if (pending < 3) {
            continue;
        }
        // A full 3-byte group becomes exactly four output characters.
        encoded += base64_chars.to[ (group[0] & 0xfc) >> 2 ];
        encoded += base64_chars.to[ ((group[0] & 0x03) << 4) + ((group[1] & 0xf0) >> 4) ];
        encoded += base64_chars.to[ ((group[1] & 0x0f) << 2) + ((group[2] & 0xc0) >> 6) ];
        encoded += base64_chars.to[ group[2] & 0x3f ];
        pending = 0;
    }

    if (pending != 0) {
        // One or two leftover bytes: zero the unused slots so the bit
        // arithmetic below only sees real input bits.
        for (int k = pending; k < 3; ++k) {
            group[k] = '\0';
        }
        encoded += base64_chars.to[ (group[0] & 0xfc) >> 2 ];
        encoded += base64_chars.to[ ((group[0] & 0x03) << 4) + ((group[1] & 0xf0) >> 4) ];
        // With two leftover bytes the third character carries data;
        // with one leftover byte it is padding.
        encoded += (pending == 2)
                ? base64_chars.to[ ((group[1] & 0x0f) << 2) + ((group[2] & 0xc0) >> 6) ]
                : '=';
        encoded += '=';
    }
    return encoded;
}
// Counts how many of the last two characters of str are '=' (0, 1 or 2).
// Each of the two positions is checked independently, and positions that
// do not exist (str shorter than 1 or 2 characters) count as non-padding.
static size_t base64_padding_len(std::string_view str) {
    const size_t len = str.size();
    const bool last_is_pad = len >= 1 && str[len - 1] == '=';
    const bool second_last_is_pad = len >= 2 && str[len - 2] == '=';
    return static_cast<size_t>(last_is_pad) + static_cast<size_t>(second_last_is_pad);
}
// Shortens the view `str` so that it no longer includes the trailing '='
// padding reported by base64_padding_len(). Only the view is modified.
// Throws std::invalid_argument if the length of str is not a multiple of 4.
static void base64_trim_padding(std::string_view& str) {
    const bool length_ok = (str.size() % 4 == 0);
    if (!length_ok) {
        throw std::invalid_argument(format("Base64 encoded length is expected a multiple of 4 bytes but found: {}", str.size()));
    }
    const size_t pad = base64_padding_len(str);
    str = str.substr(0, str.size() - pad);
}
// Decodes base64 text into a std::string holding the raw decoded bytes.
// Throws std::invalid_argument if the input length is not a multiple of 4
// (checked by base64_trim_padding) or if, after the trailing padding was
// removed, a character outside the base64 alphabet is found.
static std::string base64_decode_string(std::string_view in) {
    base64_trim_padding(in);

    std::string decoded;
    decoded.reserve(in.size() * 3 / 4);

    int8_t sextets[4]; // pending input characters, each converted to 0..63
    int filled = 0;    // how many entries of sextets[] are valid
    for (unsigned char c : in) {
        const uint8_t value = base64_chars.from[c];
        if (value == base64_chars::invalid_char) {
            throw std::invalid_argument(format("Invalid Base64 character: '{}'", char(c)));
        }
        sextets[filled] = value;
        ++filled;
        if (filled < 4) {
            continue;
        }
        // Four 6-bit values make up three output bytes.
        decoded += (sextets[0] << 2) + ((sextets[1] & 0x30) >> 4);
        decoded += ((sextets[1] & 0xf) << 4) + ((sextets[2] & 0x3c) >> 2);
        decoded += ((sextets[2] & 0x3) << 6) + sextets[3];
        filled = 0;
    }

    // A trailing partial group of 2 or 3 characters yields 1 or 2 more
    // output bytes respectively.
    if (filled >= 2) {
        decoded += (sextets[0] << 2) + ((sextets[1] & 0x30) >> 4);
        if (filled == 3) {
            decoded += ((sextets[1] & 0xf) << 4) + ((sextets[2] & 0x3c) >> 2);
        }
    }
    return decoded;
}
// Decodes base64 text and returns the result as `bytes`.
// Any exception thrown by base64_decode_string() propagates to the caller.
bytes base64_decode(std::string_view in) {
    // FIXME: the decoded data is built in a std::string and then copied
    // into "bytes", because "bytes" has no efficient append. Avoiding the
    // copy requires using the "uninitialized" feature of bytes.
    std::string decoded = base64_decode_string(in);
    return bytes(decoded.begin(), decoded.end());
}
// Computes the decoded length of base64 text without decoding it: three
// bytes per complete group of four characters, minus the number of
// trailing '=' padding characters.
// The input is not validated. NOTE: the subtraction is unsigned, so for an
// input shorter than 4 characters that ends in '=' the result wraps around.
size_t base64_decoded_len(std::string_view str) {
    const size_t full_groups = str.size() / 4;
    const size_t padding = base64_padding_len(str);
    return full_groups * 3 - padding;
}
// Checks, working on the base64-encoded forms, whether the data encoded by
// `base` begins with the data encoded by `operand`.
// Returns false if either length is not a multiple of 4 or if operand is
// longer than base. May propagate std::invalid_argument from
// base64_decode_string() when the last group has to be decoded.
bool base64_begins_with(std::string_view base, std::string_view operand) {
    const bool lengths_valid = base.size() % 4 == 0 && operand.size() % 4 == 0;
    if (!lengths_valid || operand.size() > base.size()) {
        return false;
    }
    // Without padding, operand encodes a whole number of 3-byte groups, so
    // a plain textual prefix comparison is sufficient.
    if (base64_padding_len(operand) == 0) {
        return base.starts_with(operand);
    }
    // operand ends with padding, hence it is non-empty and (being a
    // multiple of 4 long) has at least 4 characters. Everything before its
    // last 4-character group can still be compared textually.
    const size_t head_len = operand.size() - 4;
    if (base.substr(0, head_len) != operand.substr(0, head_len)) {
        return false;
    }
    // The final group of operand is padded, so decode it together with the
    // group of base at the same position and compare the decoded bytes.
    const std::string base_tail = base64_decode_string(base.substr(head_len, 4));
    const std::string operand_tail = base64_decode_string(operand.substr(head_len));
    return base_tail.starts_with(operand_tail);
}
// Encodes `in` with the URL-safe base64 variant: the output of
// base64_encode() with '+' replaced by '-', '/' replaced by '_', and
// everything from the first '=' onwards removed.
std::string base64url_encode(bytes_view in) {
    std::string encoded = base64_encode(in);
    for (char& ch : encoded) {
        switch (ch) {
        case '+':
            ch = '-';
            break;
        case '/':
            ch = '_';
            break;
        default:
            break;
        }
    }
    // Drop the padding, if there is any.
    const auto first_pad = encoded.find('=');
    if (first_pad != std::string::npos) {
        encoded.erase(first_pad);
    }
    return encoded;
}
// Decodes URL-safe base64 text (using '-' and '_' instead of '+' and '/',
// with the trailing '=' padding optionally omitted) and returns the result
// as `bytes`.
// Throws std::invalid_argument when the input length is 1 modulo 4, which
// no amount of padding can turn into valid base64. Exceptions thrown by
// base64_decode() for the converted string propagate to the caller.
bytes base64url_decode(std::string_view in) {
    std::string str{in};
    const size_t mod = str.size() % 4;
    if (mod == 1) {
        // The exception has to be thrown, not merely constructed: otherwise
        // the invalid input silently continues on to base64_decode().
        throw std::invalid_argument(seastar::format("Base64 encoded length is invalid: {}", str.size()));
    } else if (mod == 2) {
        str.append("==", 2);
    } else if (mod == 3) {
        str.append("=", 1);
    }
    // Map the URL-safe alphabet back to the standard one.
    for (char& c : str) {
        if (c == '-') {
            c = '+';
        } else if (c == '_') {
            c = '/';
        }
    }
    return base64_decode(str);
}