Files
scylladb/tools/patchelf.cc
Avi Kivity 0ae22a09d4 LICENSE: Update to version 1.1
Updated terms of non-commercial use (must be a never-customer).
2026-04-12 19:46:33 +03:00

435 lines
16 KiB
C++

// Copyright (C) 2025-present ScyllaDB
// SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.1
#include <bit>
#include <concepts>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>

#include <elf.h>
#include <getopt.h>
#include <sys/stat.h>
// Minimal in-memory editor for the program interpreter of 64-bit ELF files.
//
// The whole file is loaded into `_data`. ELF structures are addressed by file
// offset rather than by pointer: set_interpreter() grows the buffer, and any
// pointer into it would dangle after a reallocation. Multi-byte fields are
// decoded/encoded byte by byte in the file's declared byte order. As a result,
// neither host endianness nor the alignment of offsets taken from the
// (possibly malformed) file matters.
class elf_patcher {
private:
    std::vector<uint8_t> _data;
    // True for ELFDATA2LSB files, false for ELFDATA2MSB.
    bool _is_little_endian = true;
    // Contents of the section-name string table (.shstrtab). Empty when the
    // file has none or it lies outside the file.
    std::string _string_table;
    std::string _input_filename;

    // Returns true when [offset, offset + len) lies inside the file image.
    // Phrased so that a huge offset read from the file cannot wrap around.
    bool in_bounds(uint64_t offset, uint64_t len) const {
        return offset <= _data.size() && len <= _data.size() - offset;
    }

    // Reads the unsigned integer of type T stored at file offset `offset`,
    // honouring the file's byte order.
    // Throws std::runtime_error if the field extends past the end of the file.
    template <typename T>
    T read_value(uint64_t offset) const {
        static_assert(std::is_unsigned_v<T>, "ELF fields are read as unsigned integers");
        if (!in_bounds(offset, sizeof(T))) {
            throw std::runtime_error("Truncated ELF file");
        }
        uint64_t val = 0;
        for (size_t i = 0; i < sizeof(T); i++) {
            // Consume bytes from most to least significant.
            size_t idx = _is_little_endian ? sizeof(T) - 1 - i : i;
            val = (val << 8) | _data[offset + idx];
        }
        return static_cast<T>(val);
    }

    // Stores `val` at file offset `offset` in the file's byte order.
    // Throws std::runtime_error if the field extends past the end of the file.
    template <typename T>
    void write_value(uint64_t offset, T val) {
        static_assert(std::is_unsigned_v<T>, "ELF fields are written as unsigned integers");
        if (!in_bounds(offset, sizeof(T))) {
            throw std::runtime_error("Truncated ELF file");
        }
        uint64_t rest = val;
        for (size_t i = 0; i < sizeof(T); i++) {
            // Emit bytes from least to most significant.
            size_t idx = _is_little_endian ? i : sizeof(T) - 1 - i;
            _data[offset + idx] = static_cast<uint8_t>(rest & 0xff);
            rest >>= 8;
        }
    }

    // ELF header fields, always read from the current buffer contents.
    uint64_t phoff() const { return read_value<uint64_t>(offsetof(Elf64_Ehdr, e_phoff)); }
    uint16_t phnum() const { return read_value<uint16_t>(offsetof(Elf64_Ehdr, e_phnum)); }
    uint64_t shoff() const { return read_value<uint64_t>(offsetof(Elf64_Ehdr, e_shoff)); }
    uint16_t shnum() const { return read_value<uint16_t>(offsetof(Elf64_Ehdr, e_shnum)); }
    uint16_t shstrndx() const { return read_value<uint16_t>(offsetof(Elf64_Ehdr, e_shstrndx)); }

    // File offset of the i-th program header / section header.
    uint64_t phdr_at(uint16_t i) const { return phoff() + static_cast<uint64_t>(i) * sizeof(Elf64_Phdr); }
    uint64_t shdr_at(uint16_t i) const { return shoff() + static_cast<uint64_t>(i) * sizeof(Elf64_Shdr); }

    // True when the file declares a non-empty section header table.
    bool has_section_table() const { return shoff() != 0 && shnum() > 0; }

    // Returns the bytes in [offset, offset + max_len) up to, but excluding,
    // the first NUL, or all of them if there is none. The caller must have
    // checked the range with in_bounds().
    std::string read_cstring(uint64_t offset, uint64_t max_len) const {
        const char* start = reinterpret_cast<const char*>(_data.data() + offset);
        const void* nul = std::memchr(start, '\0', max_len);
        size_t len = nul ? static_cast<size_t>(static_cast<const char*>(nul) - start) : max_len;
        return std::string(start, len);
    }

    // Appends `len` bytes at the end of the image and returns their file
    // offset. Before the copy, the image is zero-padded to a multiple of
    // `align`. `src` must not point into _data, because growing the buffer
    // may reallocate it.
    uint64_t append_bytes(const void* src, size_t len, size_t align = 1) {
        size_t offset = (_data.size() + align - 1) / align * align;
        _data.resize(offset + len);
        if (len != 0) {
            std::memcpy(_data.data() + offset, src, len);
        }
        return offset;
    }

    // Validates the ELF header and the program/section header tables, and
    // caches the section-name string table.
    // Throws std::runtime_error for files this tool cannot handle safely.
    void parse_elf_header() {
        if (_data.size() < sizeof(Elf64_Ehdr)) {
            throw std::runtime_error("File too small to be a valid ELF");
        }
        if (std::memcmp(_data.data(), ELFMAG, SELFMAG) != 0) {
            throw std::runtime_error("Not a valid ELF file");
        }
        if (_data[EI_CLASS] != ELFCLASS64) {
            throw std::runtime_error("Only 64-bit ELF files are supported");
        }
        switch (_data[EI_DATA]) {
        case ELFDATA2LSB:
            _is_little_endian = true;
            break;
        case ELFDATA2MSB:
            _is_little_endian = false;
            break;
        default:
            throw std::runtime_error("Unsupported ELF data encoding");
        }
        // Program headers are walked as an array of Elf64_Phdr. The declared
        // entry size must therefore match, and the whole table must lie
        // inside the file.
        uint16_t ph_count = phnum();
        bool bad_phentsize = ph_count > 0 &&
                read_value<uint16_t>(offsetof(Elf64_Ehdr, e_phentsize)) != sizeof(Elf64_Phdr);
        if (bad_phentsize || !in_bounds(phoff(), static_cast<uint64_t>(ph_count) * sizeof(Elf64_Phdr))) {
            throw std::runtime_error("Invalid program header table");
        }
        _string_table.clear();
        if (!has_section_table()) {
            return;
        }
        uint16_t sh_count = shnum();
        bool bad_shentsize = read_value<uint16_t>(offsetof(Elf64_Ehdr, e_shentsize)) != sizeof(Elf64_Shdr);
        if (bad_shentsize || !in_bounds(shoff(), static_cast<uint64_t>(sh_count) * sizeof(Elf64_Shdr))) {
            throw std::runtime_error("Invalid section header table");
        }
        // A missing or truncated .shstrtab is tolerated. It just means that no
        // section can be found by name.
        uint16_t strndx = shstrndx();
        if (strndx < sh_count) {
            uint64_t strtab_offset = read_value<uint64_t>(shdr_at(strndx) + offsetof(Elf64_Shdr, sh_offset));
            uint64_t strtab_size = read_value<uint64_t>(shdr_at(strndx) + offsetof(Elf64_Shdr, sh_size));
            if (in_bounds(strtab_offset, strtab_size)) {
                _string_table.assign(reinterpret_cast<const char*>(_data.data() + strtab_offset), strtab_size);
            }
        }
    }

    // Returns the file offset of the section header named ".interp", if any.
    std::optional<uint64_t> find_interp_section() const {
        if (!has_section_table()) {
            return std::nullopt;
        }
        uint16_t count = shnum();
        for (uint16_t i = 0; i < count; i++) {
            uint32_t sh_name = read_value<uint32_t>(shdr_at(i) + offsetof(Elf64_Shdr, sh_name));
            // c_str() guarantees a terminator even if the table lacks one.
            if (sh_name < _string_table.size() &&
                    std::strcmp(_string_table.c_str() + sh_name, ".interp") == 0) {
                return shdr_at(i);
            }
        }
        return std::nullopt;
    }

    // Returns the file offset of the first PT_INTERP program header, if any.
    std::optional<uint64_t> find_interp_segment() const {
        uint16_t count = phnum();
        for (uint16_t i = 0; i < count; i++) {
            if (read_value<uint32_t>(phdr_at(i) + offsetof(Elf64_Phdr, p_type)) == PT_INTERP) {
                return phdr_at(i);
            }
        }
        return std::nullopt;
    }

    // Adds a ".interp" section header describing [offset, offset + size).
    //
    // Nothing already in the file is overwritten. An extended copy of
    // .shstrtab and an extended copy of the section header table are appended
    // at the end of the image, then e_shoff / e_shnum and the string table's
    // header are repointed at the copies. The old copies remain as unreferenced
    // bytes. The section is silently skipped when it cannot be added safely:
    // no section table, no usable .shstrtab, or a count that would require
    // extended (SHN_LORESERVE) numbering.
    void create_interp_section(uint64_t offset, uint64_t size) {
        if (!has_section_table()) {
            return;
        }
        uint16_t count = shnum();
        uint16_t strndx = shstrndx();
        if (strndx >= count || count + 1 >= SHN_LORESERVE) {
            return;
        }
        uint64_t old_strtab_size = read_value<uint64_t>(shdr_at(strndx) + offsetof(Elf64_Shdr, sh_size));
        if (_string_table.size() != old_strtab_size) {
            // .shstrtab could not be loaded, so existing names cannot be preserved.
            return;
        }

        std::string new_strtab = _string_table;
        if (new_strtab.empty()) {
            // Keep offset 0 meaning "no name" for the existing sections.
            new_strtab.push_back('\0');
        }
        uint32_t name_offset = static_cast<uint32_t>(new_strtab.size());
        new_strtab += ".interp";
        new_strtab += '\0';

        // Copy the current table out before appending, because appending may
        // reallocate _data. Then add one zero-initialised entry for .interp.
        uint64_t old_shoff = shoff();
        size_t table_bytes = static_cast<size_t>(count) * sizeof(Elf64_Shdr);
        std::vector<uint8_t> table(_data.data() + old_shoff, _data.data() + old_shoff + table_bytes);
        table.resize(table_bytes + sizeof(Elf64_Shdr));

        uint64_t new_strtab_offset = append_bytes(new_strtab.data(), new_strtab.size());
        uint64_t new_shoff = append_bytes(table.data(), table.size(), alignof(Elf64_Shdr));
        write_value<uint64_t>(offsetof(Elf64_Ehdr, e_shoff), new_shoff);
        write_value<uint16_t>(offsetof(Elf64_Ehdr, e_shnum), static_cast<uint16_t>(count + 1));

        // shdr_at() now addresses the relocated table.
        uint64_t strtab_hdr = shdr_at(strndx);
        write_value<uint64_t>(strtab_hdr + offsetof(Elf64_Shdr, sh_offset), new_strtab_offset);
        write_value<uint64_t>(strtab_hdr + offsetof(Elf64_Shdr, sh_size), new_strtab.size());

        // Fields not written here (sh_addr, sh_link, sh_info, sh_entsize) stay zero.
        uint64_t interp_hdr = shdr_at(count);
        write_value<uint32_t>(interp_hdr + offsetof(Elf64_Shdr, sh_name), name_offset);
        write_value<uint32_t>(interp_hdr + offsetof(Elf64_Shdr, sh_type), SHT_PROGBITS);
        write_value<uint64_t>(interp_hdr + offsetof(Elf64_Shdr, sh_flags), SHF_ALLOC);
        write_value<uint64_t>(interp_hdr + offsetof(Elf64_Shdr, sh_offset), offset);
        write_value<uint64_t>(interp_hdr + offsetof(Elf64_Shdr, sh_size), size);
        write_value<uint64_t>(interp_hdr + offsetof(Elf64_Shdr, sh_addralign), 1);

        _string_table = std::move(new_strtab);
    }

public:
    // Loads `filename` into memory and validates its ELF headers.
    // Throws std::runtime_error if the file cannot be read or is not a
    // supported 64-bit ELF file.
    explicit elf_patcher(const std::string& filename) : _input_filename(filename) {
        std::ifstream file(filename, std::ios::binary);
        if (!file) {
            throw std::runtime_error("Cannot open file: " + filename);
        }
        file.seekg(0, std::ios::end);
        std::streamoff size = file.tellg();
        if (size < 0) {
            // tellg() reports failure as -1; do not turn that into a huge resize.
            throw std::runtime_error("Error reading file: " + filename);
        }
        file.seekg(0, std::ios::beg);
        _data.resize(static_cast<size_t>(size));
        file.read(reinterpret_cast<char*>(_data.data()), static_cast<std::streamsize>(size));
        if (!file) {
            throw std::runtime_error("Error reading file: " + filename);
        }
        parse_elf_header();
    }

    // Returns the interpreter path. The .interp section is consulted first;
    // the PT_INTERP segment is the fallback. The string stops at the first
    // NUL. Returns "" if the file has neither.
    // Throws std::runtime_error if the PT_INTERP segment lies outside the file.
    std::string get_current_interpreter() const {
        if (auto shdr = find_interp_section()) {
            uint64_t sh_offset = read_value<uint64_t>(*shdr + offsetof(Elf64_Shdr, sh_offset));
            uint64_t sh_size = read_value<uint64_t>(*shdr + offsetof(Elf64_Shdr, sh_size));
            if (in_bounds(sh_offset, sh_size)) {
                return read_cstring(sh_offset, sh_size);
            }
        }
        auto phdr = find_interp_segment();
        if (!phdr) {
            return "";
        }
        uint64_t p_offset = read_value<uint64_t>(*phdr + offsetof(Elf64_Phdr, p_offset));
        uint64_t p_filesz = read_value<uint64_t>(*phdr + offsetof(Elf64_Phdr, p_filesz));
        if (!in_bounds(p_offset, p_filesz)) {
            throw std::runtime_error("Invalid interpreter segment");
        }
        return read_cstring(p_offset, p_filesz);
    }

    // Appends `new_interp` plus a terminating NUL to the image. It then
    // repoints the PT_INTERP segment (p_offset/p_filesz/p_memsz) and the
    // .interp section (sh_offset/sh_size) at the appended string. If section
    // headers exist but no .interp section does, one is created.
    // NOTE(review): p_vaddr/p_paddr and sh_addr are left unchanged and still
    // refer to the old string's address; confirm nothing relies on them.
    void set_interpreter(const std::string& new_interp) {
        // c_str()[size()] is the terminating NUL, so this copies it too.
        uint64_t new_len = new_interp.size() + 1;
        uint64_t new_offset = append_bytes(new_interp.c_str(), new_len);

        if (auto phdr = find_interp_segment()) {
            write_value<uint64_t>(*phdr + offsetof(Elf64_Phdr, p_offset), new_offset);
            write_value<uint64_t>(*phdr + offsetof(Elf64_Phdr, p_filesz), new_len);
            write_value<uint64_t>(*phdr + offsetof(Elf64_Phdr, p_memsz), new_len);
        }
        if (has_section_table()) {
            if (auto shdr = find_interp_section()) {
                write_value<uint64_t>(*shdr + offsetof(Elf64_Shdr, sh_offset), new_offset);
                write_value<uint64_t>(*shdr + offsetof(Elf64_Shdr, sh_size), new_len);
            } else {
                create_interp_section(new_offset, new_len);
            }
        }
    }

    // Writes the image to `filename + ".tmp"` and copies the input file's
    // permission bits onto it. It then renames the temp file over `filename`.
    // The temp file is removed on failure. The data is not fsync'ed before
    // the rename.
    // Throws std::runtime_error on any failure.
    void save(const std::string& filename) {
        struct stat file_stat;
        if (stat(_input_filename.c_str(), &file_stat) != 0) {
            throw std::runtime_error("Cannot get file permissions for: " + _input_filename);
        }
        std::string temp_filename = filename + ".tmp";
        {
            std::ofstream file(temp_filename, std::ios::binary);
            if (!file) {
                throw std::runtime_error("Cannot create output file: " + temp_filename);
            }
            file.write(reinterpret_cast<const char*>(_data.data()), static_cast<std::streamsize>(_data.size()));
            // close() flushes the buffer. A failed flush (e.g. a full disk)
            // only shows up in the stream state after this call.
            file.close();
            if (!file) {
                std::remove(temp_filename.c_str());
                throw std::runtime_error("Error writing to file: " + temp_filename);
            }
        }
        // Only permission bits are meaningful to chmod().
        if (chmod(temp_filename.c_str(), file_stat.st_mode & 07777) != 0) {
            std::remove(temp_filename.c_str());
            throw std::runtime_error("Cannot set permissions on temp file: " + temp_filename);
        }
        if (std::rename(temp_filename.c_str(), filename.c_str()) != 0) {
            std::remove(temp_filename.c_str());
            throw std::runtime_error("Cannot replace original file: " + filename);
        }
    }
};
// Writes the command-line synopsis and the list of supported options to
// standard error. `progname` is echoed verbatim in the "Usage:" line.
void usage(const char* progname) {
    std::cerr << "Usage: " << progname << " [OPTIONS] FILE\n"
              << "Options:\n"
              << " --set-interpreter INTERPRETER Set the ELF interpreter\n"
              << " --print-interpreter Print the current ELF interpreter\n"
              << " --output FILE Output file (default: modify in place)\n"
              << " --help Show this help\n";
}
// Command-line entry point.
//
// Parses --set-interpreter / --print-interpreter / --output / --help and
// requires exactly one input FILE operand (extra operands are ignored).
// Printing happens before setting when both are requested. Without --output
// the input file is rewritten in place.
// Returns 0 on success and 1 on usage errors, on a missing interpreter when
// printing, or on any exception raised while patching.
int main(int argc, char* argv[]) {
    static const struct option options[] = {
        {"set-interpreter", required_argument, nullptr, 's'},
        {"print-interpreter", no_argument, nullptr, 'p'},
        {"output", required_argument, nullptr, 'o'},
        {"help", no_argument, nullptr, 'h'},
        {nullptr, 0, nullptr, 0},
    };
    std::string new_interpreter;
    std::string output_path;
    bool want_print = false;

    for (int opt; (opt = getopt_long(argc, argv, "s:po:h", options, nullptr)) != -1;) {
        if (opt == 's') {
            new_interpreter = optarg;
        } else if (opt == 'p') {
            want_print = true;
        } else if (opt == 'o') {
            output_path = optarg;
        } else {
            // --help exits cleanly; any unrecognised option is an error.
            usage(argv[0]);
            return opt == 'h' ? 0 : 1;
        }
    }

    // Report a command-line mistake followed by the usage text.
    auto reject = [&](const char* message) {
        std::cerr << message;
        usage(argv[0]);
        return 1;
    };
    if (optind >= argc) {
        return reject("Missing input file\n");
    }
    if (!want_print && new_interpreter.empty()) {
        return reject("Must specify either --set-interpreter or --print-interpreter\n");
    }

    const std::string input_path = argv[optind];
    try {
        elf_patcher patcher(input_path);
        if (want_print) {
            const std::string current = patcher.get_current_interpreter();
            if (current.empty()) {
                std::cerr << "No interpreter found\n";
                return 1;
            }
            std::cout << current << '\n';
        }
        if (!new_interpreter.empty()) {
            patcher.set_interpreter(new_interpreter);
            patcher.save(output_path.empty() ? input_path : output_path);
            std::cout << "Set interpreter to: " << new_interpreter << '\n';
        }
    } catch (const std::exception& e) {
        std::cerr << "Error: " << e.what() << '\n';
        return 1;
    }
    return 0;
}