Compare commits

..

11 Commits

Author SHA1 Message Date
copilot-swe-agent[bot]
aebf71b5a5 Optimize: fetch tablet map once per table group
Move get_tablet_map() call before the should_break check to avoid
fetching the same tablet map twice for each table group.

Co-authored-by: tgrabiec <283695+tgrabiec@users.noreply.github.com>
2026-02-12 15:10:13 +00:00
copilot-swe-agent[bot]
0f3bd7d91b Refactor: extract logging to method and improve format
- Extract logging logic into log_active_transitions() method with
  max_count as parameter for better reusability
- Change log format to "Active {kind} transition: ..." to make the
  transition kind more prominent at the beginning of the message

Co-authored-by: tgrabiec <283695+tgrabiec@users.noreply.github.com>
2026-02-12 15:09:12 +00:00
copilot-swe-agent[bot]
c29d976e40 Use single info() statement for transition logging
Consolidated 4 separate rtlogger.info() calls into one that
conditionally appends leaving and pending replica information.
This makes the code cleaner while maintaining the same output.

Co-authored-by: tgrabiec <283695+tgrabiec@users.noreply.github.com>
2026-02-12 15:05:38 +00:00
copilot-swe-agent[bot]
b08bd1dac4 Optimize transition counting to avoid unnecessary iteration
After logging 5 transitions, use transitions().size() to efficiently
count remaining transitions in other table groups instead of iterating
through each one. This improves performance when there are many active
transitions.

Co-authored-by: tgrabiec <283695+tgrabiec@users.noreply.github.com>
2026-02-12 13:27:46 +00:00
copilot-swe-agent[bot]
9bd784f15c Limit transition logging to 5, check optionals, remove yields
- Only log leaving/pending replicas if they exist (check optional values)
- Limit logging to first 5 transitions, print "... and N more" if more exist
- Remove co_await coroutine::maybe_yield() calls as requested
- Handle all combinations of leaving/pending presence in log messages

Co-authored-by: tgrabiec <283695+tgrabiec@users.noreply.github.com>
2026-02-12 13:26:44 +00:00
copilot-swe-agent[bot]
6b27f8002b Add clarifying comment about logging only involved replicas
Add comment explaining why we log only leaving/pending replicas
rather than full replica sets - to focus on what's actually changing
in the transition.

Co-authored-by: tgrabiec <283695+tgrabiec@users.noreply.github.com>
2026-02-11 23:56:20 +00:00
copilot-swe-agent[bot]
cd62a4d669 Log only leaving and pending replicas in transition
Changed the transition logging to print only the replicas involved in
the transition (leaving and pending) instead of all current and next
replicas. This makes the log output more concise and focused on what's
actually changing.

Example output:
Active transition: tablet=<uuid>/<id>, kind=migration, stage=streaming,
  leaving=<host1>:0, pending=<host2>:1

Co-authored-by: tgrabiec <283695+tgrabiec@users.noreply.github.com>
2026-02-11 23:55:40 +00:00
copilot-swe-agent[bot]
f1543e9075 Mark unused 'tables' variable with [[maybe_unused]]
Address code review feedback by marking the 'tables' variable as
[[maybe_unused]] since it's part of the structured binding from
all_table_groups() but not used in the logging loop.

Co-authored-by: tgrabiec <283695+tgrabiec@users.noreply.github.com>
2026-02-11 23:51:10 +00:00
copilot-swe-agent[bot]
b583dc6a72 Add coroutine yield points to transition logging loop
Add co_await coroutine::maybe_yield() at each iteration to prevent
reactor stalls when iterating over potentially large numbers of
tablet transitions.

Co-authored-by: tgrabiec <283695+tgrabiec@users.noreply.github.com>
2026-02-11 23:50:21 +00:00
copilot-swe-agent[bot]
d4a041da59 Add logging of active tablet transitions before sleep
Log detailed information about each active tablet transition when the
topology coordinator goes to sleep with active transitions. This includes:
- Tablet ID (table + tablet)
- Transition kind (migration, rebuild, etc.)
- Transition stage
- Current replicas
- Next replicas (destination of the transition)

Co-authored-by: tgrabiec <283695+tgrabiec@users.noreply.github.com>
2026-02-11 23:49:47 +00:00
copilot-swe-agent[bot]
fa0f23a863 Initial plan 2026-02-11 23:45:57 +00:00
77 changed files with 585 additions and 2146 deletions

View File

@@ -1,22 +0,0 @@
name: Sync Jira Based on PR Milestone Events
on:
pull_request_target:
types: [milestoned, demilestoned]
permissions:
contents: read
pull-requests: read
jobs:
jira-sync-milestone-set:
if: github.event.action == 'milestoned'
uses: scylladb/github-automation/.github/workflows/main_jira_sync_pr_milestone_set.yml@main
secrets:
caller_jira_auth: ${{ secrets.USER_AND_KEY_FOR_JIRA_AUTOMATION }}
jira-sync-milestone-removed:
if: github.event.action == 'demilestoned'
uses: scylladb/github-automation/.github/workflows/main_jira_sync_pr_milestone_removed.yml@main
secrets:
caller_jira_auth: ${{ secrets.USER_AND_KEY_FOR_JIRA_AUTOMATION }}

View File

@@ -1,4 +1,4 @@
name: Call Jira release creation for new milestone
name: Call Jira release creation for new milestone
on:
milestone:
@@ -9,6 +9,6 @@ jobs:
uses: scylladb/github-automation/.github/workflows/main_sync_milestone_to_jira_release.yml@main
with:
# Comma-separated list of Jira project keys
jira_project_keys: "SCYLLADB,CUSTOMER,SMI"
jira_project_keys: "SCYLLADB,CUSTOMER"
secrets:
caller_jira_auth: ${{ secrets.USER_AND_KEY_FOR_JIRA_AUTOMATION }}

View File

@@ -1,62 +0,0 @@
name: Close issues created by Scylla associates
on:
issues:
types: [opened, reopened]
permissions:
issues: write
jobs:
comment-and-close:
runs-on: ubuntu-latest
steps:
- name: Comment and close if author email is scylladb.com
uses: actions/github-script@v7
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
const issue = context.payload.issue;
const actor = context.actor;
// Get user data (only public email is available)
const { data: user } = await github.rest.users.getByUsername({
username: actor,
});
const email = user.email || "";
console.log(`Actor: ${actor}, public email: ${email || "<none>"}`);
// Only continue if email exists and ends with @scylladb.com
if (!email || !email.toLowerCase().endsWith("@scylladb.com")) {
console.log("User is not a scylladb.com email (or email not public); skipping.");
return;
}
const owner = context.repo.owner;
const repo = context.repo.repo;
const issue_number = issue.number;
const body = "Issues in this repository are closed automatically. Scylla associates should use Jira to manage issues.\nPlease move this issue to Jira https://scylladb.atlassian.net/jira/software/c/projects/SCYLLADB/list";
// Add the comment
await github.rest.issues.createComment({
owner,
repo,
issue_number,
body,
});
console.log(`Comment added to #${issue_number}`);
// Close the issue
await github.rest.issues.update({
owner,
repo,
issue_number,
state: "closed",
state_reason: "not_planned"
});
console.log(`Issue #${issue_number} closed.`);

View File

@@ -9,28 +9,10 @@ on:
jobs:
trigger-jenkins:
if: (github.event_name == 'issue_comment' && github.event.comment.user.login != 'scylladbbot') || github.event.label.name == 'conflicts'
if: (github.event.comment.user.login != 'scylladbbot' && contains(github.event.comment.body, '@scylladbbot') && contains(github.event.comment.body, 'trigger-ci')) || github.event.label.name == 'conflicts'
runs-on: ubuntu-latest
steps:
- name: Validate Comment Trigger
if: github.event_name == 'issue_comment'
id: verify_comment
shell: bash
run: |
BODY=$(cat << 'EOF'
${{ github.event.comment.body }}
EOF
)
CLEAN_BODY=$(echo "$BODY" | grep -v '^[[:space:]]*>')
if echo "$CLEAN_BODY" | grep -qi '@scylladbbot' && echo "$CLEAN_BODY" | grep -qi 'trigger-ci'; then
echo "trigger=true" >> $GITHUB_OUTPUT
else
echo "trigger=false" >> $GITHUB_OUTPUT
fi
- name: Trigger Scylla-CI-Route Jenkins Job
if: github.event_name == 'pull_request_target' || steps.verify_comment.outputs.trigger == 'true'
env:
JENKINS_USER: ${{ secrets.JENKINS_USERNAME }}
JENKINS_API_TOKEN: ${{ secrets.JENKINS_TOKEN }}

View File

@@ -1,197 +0,0 @@
# Implementation Summary: Error Injection Event Stream
## Problem Statement
Tests using error injections had to rely on log parsing to detect when injection points were hit:
```python
mark, _ = await log.wait_for('topology_coordinator_pause_before_processing_backlog: waiting', from_mark=mark)
```
This approach was:
- **Slow**: Required waiting for log flushes and buffer processing
- **Unreliable**: Regex matching could fail or match wrong lines
- **Fragile**: Changes to log messages broke tests
## Solution
Implemented a Server-Sent Events (SSE) API that sends real-time notifications when error injection points are triggered.
## Implementation
### 1. Backend Event System (`utils/error_injection.hh`)
**Added**:
- `error_injection_event_callback` type for event notifications
- `_event_callbacks` vector to store registered callbacks
- `notify_event()` method called by all `inject()` methods
- `register_event_callback()` / `clear_event_callbacks()` methods
- Cross-shard registration via `register_event_callback_on_all()`
**Modified**:
- All `inject()` methods now call `notify_event()` after logging
- Changed log level from DEBUG to INFO for better visibility
- Both enabled/disabled template specializations updated
### 2. SSE API Endpoint (`api/error_injection.cc`)
**Added**:
- `GET /v2/error_injection/events` endpoint
- Streams events in SSE format: `data: {"injection":"name","type":"handler","shard":0}\n\n`
- Cross-shard event collection using `foreign_ptr` and `smp::submit_to()`
- Automatic cleanup on client disconnect
**Architecture**:
1. Client connects → queue created on handler shard
2. Callbacks registered on ALL shards
3. When injection fires → event sent via `smp::submit_to()` to queue
4. Queue → SSE stream → client
5. Client disconnect → callbacks cleared on all shards
### 3. Python Client (`test/pylib/rest_client.py`)
**Added**:
- `InjectionEventStream` class:
- `wait_for_injection(name, timeout)` - wait for specific injection
- Background task reads SSE stream
- Queue-based event delivery
- `injection_event_stream()` context manager for lifecycle
- Full async/await support
**Usage**:
```python
async with injection_event_stream(server_ip) as stream:
await api.enable_injection(server_ip, "my_injection", one_shot=True)
# ... trigger operation ...
event = await stream.wait_for_injection("my_injection", timeout=30)
```
### 4. Tests (`test/cluster/test_error_injection_events.py`)
**Added**:
- `test_injection_event_stream_basic` - basic functionality
- `test_injection_event_stream_multiple_injections` - multiple tracking
- `test_injection_event_vs_log_parsing_comparison` - old vs new
### 5. Documentation (`docs/dev/error_injection_events.md`)
Complete documentation covering:
- Architecture and design
- Usage examples
- Migration guide from log parsing
- Thread safety and cleanup
## Key Design Decisions
### Why SSE instead of WebSocket?
- **Unidirectional**: We only need server → client events
- **Simpler**: Built on HTTP, easier to implement
- **Standard**: Well-supported in Python (aiohttp)
- **Sufficient**: No need for bidirectional communication
### Why Thread-Local Callbacks?
- **Performance**: No cross-shard synchronization overhead
- **Simplicity**: Each shard independent
- **Safety**: No shared mutable state
- Event delivery handled by `smp::submit_to()`
### Why Info Level Logging?
- **Visibility**: Events should be visible in logs AND via SSE
- **Debugging**: Easier to correlate events with log context
- **Consistency**: Matches importance of injection triggers
## Benefits
### Performance
- **Instant notification**: No waiting for log flushes
- **No regex matching**: Direct event delivery
- **Parallel processing**: Events from all shards
### Reliability
- **Type-safe**: Structured JSON events
- **No missed events**: Queue-based delivery
- **Automatic cleanup**: RAII ensures no leaks
### Developer Experience
- **Clean API**: Simple async/await pattern
- **Better errors**: Timeout on specific injection name
- **Metadata**: Event includes type and shard ID
- **Backward compatible**: Existing tests unchanged
## Testing
### Security
✅ CodeQL scan: **0 alerts** (Python)
### Validation Needed
Due to build environment limitations, the following validations are recommended:
- [ ] Build C++ code in dev mode
- [ ] Run example tests: `./test.py --mode=dev test/cluster/test_error_injection_events.py`
- [ ] Verify SSE connection lifecycle (connect, disconnect, reconnect)
- [ ] Test with multiple concurrent clients
- [ ] Verify cross-shard event delivery
- [ ] Performance comparison with log parsing
## Files Changed
```
api/api-doc/error_injection.json | 15 +++
api/error_injection.cc | 82 ++++++++++++++
docs/dev/error_injection_events.md | 132 +++++++++++++++++++++
test/cluster/test_error_injection_events.py | 140 ++++++++++++++++++++++
test/pylib/rest_client.py | 144 ++++++++++++++++++++++
utils/error_injection.hh | 81 +++++++++++++
6 files changed, 587 insertions(+), 7 deletions(-)
```
## Migration Guide
### Old Approach
```python
log = await manager.server_open_log(server.server_id)
mark = await log.mark()
await manager.api.enable_injection(server.ip_addr, "my_injection", one_shot=True)
# ... trigger operation ...
mark, _ = await log.wait_for('my_injection: waiting', from_mark=mark)
```
### New Approach
```python
async with injection_event_stream(server.ip_addr) as stream:
await manager.api.enable_injection(server.ip_addr, "my_injection", one_shot=True)
# ... trigger operation ...
event = await stream.wait_for_injection("my_injection", timeout=30)
```
### Backward Compatibility
- ✅ All existing log-based tests continue to work
- ✅ Logging still happens (now at INFO level)
- ✅ No breaking changes to existing APIs
- ✅ SSE is opt-in for new tests
## Future Enhancements
Possible improvements:
1. Server-side filtering by injection name (query parameter)
2. Include injection parameters in events
3. Add event timestamps
4. Event history/replay support
5. Multiple concurrent SSE clients per server
6. WebSocket support if bidirectional communication needed
## Conclusion
This implementation successfully addresses the problem statement:
- ✅ Eliminates log parsing
- ✅ Faster tests
- ✅ More reliable detection
- ✅ Clean API
- ✅ Backward compatible
- ✅ Well documented
- ✅ Security validated
The solution follows ScyllaDB best practices:
- RAII for resource management
- Seastar async patterns (coroutines, futures)
- Cross-shard communication via `smp::submit_to()`
- Thread-local state, no locks
- Comprehensive error handling

View File

@@ -112,21 +112,6 @@
}
]
},
{
"path":"/v2/error_injection/events",
"operations":[
{
"method":"GET",
"summary":"Subscribe to Server-Sent Events stream of error injection events",
"type":"void",
"nickname":"injection_events",
"produces":[
"text/event-stream"
],
"parameters":[]
}
]
},
{
"path":"/v2/error_injection/disconnect/{ip}",
"operations":[

View File

@@ -13,22 +13,12 @@
#include "utils/rjson.hh"
#include <seastar/core/future-util.hh>
#include <seastar/util/short_streams.hh>
#include <seastar/core/queue.hh>
#include <seastar/core/when_all.hh>
#include <seastar/core/sharded.hh>
namespace api {
using namespace seastar::httpd;
namespace hf = httpd::error_injection_json;
// Structure to hold error injection event data
struct injection_event {
sstring injection_name;
sstring injection_type;
unsigned shard_id;
};
void set_error_injection(http_context& ctx, routes& r) {
hf::enable_injection.set(r, [](std::unique_ptr<request> req) -> future<json::json_return_type> {
@@ -111,79 +101,6 @@ void set_error_injection(http_context& ctx, routes& r) {
return make_ready_future<json::json_return_type>(json::json_void());
});
});
// Server-Sent Events endpoint for injection events
// This allows clients to subscribe to real-time injection events instead of log parsing
r.add(operation_type::GET, url("/v2/error_injection/events"), [](std::unique_ptr<request> req) -> future<json::json_return_type> {
// Create a shared foreign_ptr to a queue that will receive events from all shards
// Using a queue on the current shard to collect events
using event_queue_t = seastar::queue<injection_event>;
auto event_queue = make_lw_shared<event_queue_t>();
auto queue_ptr = make_foreign(event_queue);
// Register callback on all shards to send events to our queue
auto& errinj = utils::get_local_injector();
// Capture the current shard ID for event delivery
auto target_shard = this_shard_id();
// Setup event callback that forwards events to the queue on the target shard
// Note: We use shared_ptr wrapper for foreign_ptr to make it copyable
auto callback = [queue_ptr = queue_ptr.copy(), target_shard] (std::string_view name, std::string_view type) {
injection_event evt{
.injection_name = sstring(name),
.injection_type = sstring(type),
.shard_id = this_shard_id()
};
// Send event to the target shard's queue (discard future, fire-and-forget)
(void)smp::submit_to(target_shard, [queue_ptr = queue_ptr.copy(), evt = std::move(evt)] () mutable {
return queue_ptr->push_eventually(std::move(evt));
});
};
// Register the callback on all shards
co_await errinj.register_event_callback_on_all(callback);
// Return a streaming function that sends SSE events
noncopyable_function<future<>(output_stream<char>&&)> stream_func =
[event_queue](output_stream<char>&& os) -> future<> {
auto s = std::move(os);
std::exception_ptr ex;
try {
// Send initial SSE comment to establish connection
co_await s.write(": connected\n\n");
co_await s.flush();
// Stream events as they arrive from any shard
while (true) {
auto evt = co_await event_queue->pop_eventually();
// Format as SSE event
// data: {"injection":"name","type":"handler","shard":0}
auto json_data = format("{{\"injection\":\"{}\",\"type\":\"{}\",\"shard\":{}}}",
evt.injection_name, evt.injection_type, evt.shard_id);
co_await s.write(format("data: {}\n\n", json_data));
co_await s.flush();
}
} catch (...) {
ex = std::current_exception();
}
// Cleanup: clear callbacks on all shards
co_await utils::get_local_injector().clear_event_callbacks_on_all();
co_await s.close();
if (ex) {
co_await coroutine::return_exception_ptr(std::move(ex));
}
};
co_return json::json_return_type(std::move(stream_func));
});
}
} // namespace api

View File

@@ -515,15 +515,6 @@ void set_sstables_loader(http_context& ctx, routes& r, sharded<sstables_loader>&
auto sstables = parsed.GetArray() |
std::views::transform([] (const auto& s) { return sstring(rjson::to_string_view(s)); }) |
std::ranges::to<std::vector>();
apilog.info("Restore invoked with following parameters: keyspace={}, table={}, endpoint={}, bucket={}, prefix={}, sstables_count={}, scope={}, primary_replica_only={}",
keyspace,
table,
endpoint,
bucket,
prefix,
sstables.size(),
scope,
primary_replica_only);
auto task_id = co_await sst_loader.local().download_new_sstables(keyspace, table, prefix, std::move(sstables), endpoint, bucket, scope, primary_replica_only);
co_return json::json_return_type(fmt::to_string(task_id));
});

View File

@@ -1174,7 +1174,6 @@ scylla_core = (['message/messaging_service.cc',
'utils/gz/crc_combine.cc',
'utils/gz/crc_combine_table.cc',
'utils/http.cc',
'utils/http_client_error_processing.cc',
'utils/rest/client.cc',
'utils/s3/aws_error.cc',
'utils/s3/client.cc',

View File

@@ -434,6 +434,7 @@ unaliasedSelector returns [uexpression tmp]
| K_TTL '(' c=cident ')' { tmp = column_mutation_attribute{column_mutation_attribute::attribute_kind::ttl,
unresolved_identifier{std::move(c)}}; }
| f=functionName args=selectionFunctionArgs { tmp = function_call{std::move(f), std::move(args)}; }
| f=similarityFunctionName args=vectorSimilarityArgs { tmp = function_call{std::move(f), std::move(args)}; }
| K_CAST '(' arg=unaliasedSelector K_AS t=native_type ')' { tmp = cast{.style = cast::cast_style::sql, .arg = std::move(arg), .type = std::move(t)}; }
)
( '.' fi=cident { tmp = field_selection{std::move(tmp), std::move(fi)}; }
@@ -448,6 +449,17 @@ selectionFunctionArgs returns [std::vector<expression> a]
')'
;
vectorSimilarityArgs returns [std::vector<expression> a]
: '(' ')'
| '(' v1=vectorSimilarityArg { a.push_back(std::move(v1)); }
( ',' vn=vectorSimilarityArg { a.push_back(std::move(vn)); } )*
')'
;
vectorSimilarityArg returns [uexpression a]
: s=unaliasedSelector { a = std::move(s); }
;
countArgument
: '*'
/* COUNT(1) is also allowed, it is recognized via the general function(args) path */
@@ -1694,6 +1706,10 @@ functionName returns [cql3::functions::function_name s]
: (ks=keyspaceName '.')? f=allowedFunctionName { $s.keyspace = std::move(ks); $s.name = std::move(f); }
;
similarityFunctionName returns [cql3::functions::function_name s]
: f=allowedSimilarityFunctionName { $s = cql3::functions::function_name::native_function(std::move(f)); }
;
allowedFunctionName returns [sstring s]
: f=IDENT { $s = $f.text; std::transform(s.begin(), s.end(), s.begin(), ::tolower); }
| f=QUOTED_NAME { $s = $f.text; }
@@ -1702,6 +1718,11 @@ allowedFunctionName returns [sstring s]
| K_COUNT { $s = "count"; }
;
allowedSimilarityFunctionName returns [sstring s]
: f=(K_SIMILARITY_COSINE | K_SIMILARITY_EUCLIDEAN | K_SIMILARITY_DOT_PRODUCT)
{ $s = $f.text; std::transform(s.begin(), s.end(), s.begin(), ::tolower); }
;
functionArgs returns [std::vector<expression> a]
: '(' ')'
| '(' t1=term { a.push_back(std::move(t1)); }
@@ -2398,6 +2419,10 @@ K_MUTATION_FRAGMENTS: M U T A T I O N '_' F R A G M E N T S;
K_VECTOR_SEARCH_INDEXING: V E C T O R '_' S E A R C H '_' I N D E X I N G;
K_SIMILARITY_EUCLIDEAN: S I M I L A R I T Y '_' E U C L I D E A N;
K_SIMILARITY_COSINE: S I M I L A R I T Y '_' C O S I N E;
K_SIMILARITY_DOT_PRODUCT: S I M I L A R I T Y '_' D O T '_' P R O D U C T;
// Case-insensitive alpha characters
fragment A: ('a'|'A');
fragment B: ('b'|'B');

View File

@@ -10,41 +10,9 @@
#include "types/types.hh"
#include "types/vector.hh"
#include "exceptions/exceptions.hh"
#include <span>
#include <bit>
namespace cql3 {
namespace functions {
namespace detail {
std::vector<float> extract_float_vector(const bytes_opt& param, size_t dimension) {
if (!param) {
throw exceptions::invalid_request_exception("Cannot extract float vector from null parameter");
}
const size_t expected_size = dimension * sizeof(float);
if (param->size() != expected_size) {
throw exceptions::invalid_request_exception(
fmt::format("Invalid vector size: expected {} bytes for {} floats, got {} bytes",
expected_size, dimension, param->size()));
}
std::vector<float> result;
result.reserve(dimension);
bytes_view view(*param);
for (size_t i = 0; i < dimension; ++i) {
// read_simple handles network byte order (big-endian) conversion
uint32_t raw = read_simple<uint32_t>(view);
result.push_back(std::bit_cast<float>(raw));
}
return result;
}
} // namespace detail
namespace {
// The computations of similarity scores match the exact formulas of Cassandra's (jVector's) implementation to ensure compatibility.
@@ -54,14 +22,14 @@ namespace {
// You should only use this function if you need to preserve the original vectors and cannot normalize
// them in advance.
float compute_cosine_similarity(std::span<const float> v1, std::span<const float> v2) {
float compute_cosine_similarity(const std::vector<data_value>& v1, const std::vector<data_value>& v2) {
double dot_product = 0.0;
double squared_norm_a = 0.0;
double squared_norm_b = 0.0;
for (size_t i = 0; i < v1.size(); ++i) {
double a = v1[i];
double b = v2[i];
double a = value_cast<float>(v1[i]);
double b = value_cast<float>(v2[i]);
dot_product += a * b;
squared_norm_a += a * a;
@@ -78,12 +46,12 @@ float compute_cosine_similarity(std::span<const float> v1, std::span<const float
return (1 + (dot_product / (std::sqrt(squared_norm_a * squared_norm_b)))) / 2;
}
float compute_euclidean_similarity(std::span<const float> v1, std::span<const float> v2) {
float compute_euclidean_similarity(const std::vector<data_value>& v1, const std::vector<data_value>& v2) {
double sum = 0.0;
for (size_t i = 0; i < v1.size(); ++i) {
double a = v1[i];
double b = v2[i];
double a = value_cast<float>(v1[i]);
double b = value_cast<float>(v2[i]);
double diff = a - b;
sum += diff * diff;
@@ -97,12 +65,12 @@ float compute_euclidean_similarity(std::span<const float> v1, std::span<const fl
// Assumes that both vectors are L2-normalized.
// This similarity is intended as an optimized way to perform cosine similarity calculation.
float compute_dot_product_similarity(std::span<const float> v1, std::span<const float> v2) {
float compute_dot_product_similarity(const std::vector<data_value>& v1, const std::vector<data_value>& v2) {
double dot_product = 0.0;
for (size_t i = 0; i < v1.size(); ++i) {
double a = v1[i];
double b = v2[i];
double a = value_cast<float>(v1[i]);
double b = value_cast<float>(v2[i]);
dot_product += a * b;
}
@@ -168,15 +136,13 @@ bytes_opt vector_similarity_fct::execute(std::span<const bytes_opt> parameters)
return std::nullopt;
}
// Extract dimension from the vector type
const auto& type = static_cast<const vector_type_impl&>(*arg_types()[0]);
size_t dimension = type.get_dimension();
const auto& type = arg_types()[0];
data_value v1 = type->deserialize(*parameters[0]);
data_value v2 = type->deserialize(*parameters[1]);
const auto& v1_elements = value_cast<std::vector<data_value>>(v1);
const auto& v2_elements = value_cast<std::vector<data_value>>(v2);
// Optimized path: extract floats directly from bytes, bypassing data_value overhead
std::vector<float> v1 = detail::extract_float_vector(parameters[0], dimension);
std::vector<float> v2 = detail::extract_float_vector(parameters[1], dimension);
float result = SIMILARITY_FUNCTIONS.at(_name)(v1, v2);
float result = SIMILARITY_FUNCTIONS.at(_name)(v1_elements, v2_elements);
return float_type->decompose(result);
}

View File

@@ -11,7 +11,6 @@
#include "native_scalar_function.hh"
#include "cql3/assignment_testable.hh"
#include "cql3/functions/function_name.hh"
#include <span>
namespace cql3 {
namespace functions {
@@ -20,7 +19,7 @@ static const function_name SIMILARITY_COSINE_FUNCTION_NAME = function_name::nati
static const function_name SIMILARITY_EUCLIDEAN_FUNCTION_NAME = function_name::native_function("similarity_euclidean");
static const function_name SIMILARITY_DOT_PRODUCT_FUNCTION_NAME = function_name::native_function("similarity_dot_product");
using similarity_function_t = float (*)(std::span<const float>, std::span<const float>);
using similarity_function_t = float (*)(const std::vector<data_value>&, const std::vector<data_value>&);
extern thread_local const std::unordered_map<function_name, similarity_function_t> SIMILARITY_FUNCTIONS;
std::vector<data_type> retrieve_vector_arg_types(const function_name& name, const std::vector<shared_ptr<assignment_testable>>& provided_args);
@@ -34,14 +33,5 @@ public:
virtual bytes_opt execute(std::span<const bytes_opt> parameters) override;
};
namespace detail {
// Extract float vector directly from serialized bytes, bypassing data_value overhead.
// This is an internal API exposed for testing purposes.
// Vector<float, N> wire format: N floats as big-endian uint32_t values, 4 bytes each.
std::vector<float> extract_float_vector(const bytes_opt& param, size_t dimension);
} // namespace detail
} // namespace functions
} // namespace cql3

View File

@@ -1986,13 +1986,13 @@ future<> db::commitlog::segment_manager::replenish_reserve() {
}
continue;
} catch (shutdown_marker&) {
_reserve_segments.abort(std::current_exception());
break;
} catch (...) {
clogger.warn("Exception in segment reservation: {}", std::current_exception());
}
co_await sleep(100ms);
}
_reserve_segments.abort(std::make_exception_ptr(shutdown_marker()));
}
future<std::vector<db::commitlog::descriptor>>

View File

@@ -1,132 +0,0 @@
# Error Injection Event Stream Implementation
## Overview
This implementation adds Server-Sent Events (SSE) support for error injection points, allowing tests to wait for injections to be triggered without log parsing.
## Architecture
### Backend (C++)
#### 1. Event Notification System (`utils/error_injection.hh`)
- **Callback Type**: `error_injection_event_callback` - function signature: `void(std::string_view injection_name, std::string_view injection_type)`
- **Storage**: Thread-local vector of callbacks (`_event_callbacks`)
- **Notification**: When any `inject()` method is called, `notify_event()` triggers all registered callbacks
- **Thread Safety**: Each shard has its own error_injection instance with its own callbacks
- **Cross-Shard**: Static methods use `smp::invoke_on_all()` to register callbacks on all shards
#### 2. SSE Endpoint (`api/error_injection.cc`)
```
GET /v2/error_injection/events
Content-Type: text/event-stream
```
**Flow**:
1. Client connects to SSE endpoint
2. Server creates a queue on the current shard
3. Callback registered on ALL shards that forwards events to this queue (using `smp::submit_to`)
4. Server streams events in SSE format: `data: {"injection":"name","type":"handler","shard":0}\n\n`
5. On disconnect (client closes or exception), callbacks are cleaned up
**Event Format**:
```json
{
"injection": "injection_name",
"type": "sleep|handler|exception|lambda",
"shard": 0
}
```
### Python Client (`test/pylib/rest_client.py`)
#### InjectionEventStream Class
```python
async with injection_event_stream(node_ip) as stream:
event = await stream.wait_for_injection("my_injection", timeout=30)
```
**Features**:
- Async context manager for automatic connection/disconnection
- Background task reads SSE events
- Queue-based event delivery
- `wait_for_injection()` method filters events by injection name
## Usage Examples
### Basic Usage
```python
async with injection_event_stream(server_ip) as event_stream:
# Enable injection
await api.enable_injection(server_ip, "my_injection", one_shot=True)
# Trigger operation that hits injection
# ... some operation ...
# Wait for injection without log parsing!
event = await event_stream.wait_for_injection("my_injection", timeout=30)
logger.info(f"Injection hit on shard {event['shard']}")
```
### Old vs New Approach
**Old (Log Parsing)**:
```python
log = await manager.server_open_log(server_id)
mark = await log.mark()
await api.enable_injection(ip, "my_injection", one_shot=True)
# ... operation ...
mark, _ = await log.wait_for('my_injection: waiting', from_mark=mark)
```
**New (Event Stream)**:
```python
async with injection_event_stream(ip) as stream:
await api.enable_injection(ip, "my_injection", one_shot=True)
# ... operation ...
event = await stream.wait_for_injection("my_injection", timeout=30)
```
## Benefits
1. **Performance**: No waiting for log flushes or buffer processing
2. **Reliability**: Direct event notifications, no regex matching failures
3. **Simplicity**: Clean async/await pattern
4. **Flexibility**: Can wait for multiple injections, get event metadata
5. **Backward Compatible**: Existing log-based tests continue to work
## Implementation Notes
### Thread Safety
- Each shard has independent error_injection instance
- Events from any shard are delivered to SSE client via `smp::submit_to`
- Queue operations are shard-local, avoiding cross-shard synchronization
### Cleanup
- Client disconnect triggers callback cleanup on all shards
- Cleanup happens automatically via RAII (try/finally in stream function)
- No callback leaks even if client disconnects abruptly
### Logging
- Injection triggers now log at INFO level (was DEBUG)
- This ensures events are visible in logs AND via SSE
- SSE provides machine-readable events, logs provide human-readable context
## Testing
See `test/cluster/test_error_injection_events.py` for example tests:
- `test_injection_event_stream_basic`: Basic functionality
- `test_injection_event_stream_multiple_injections`: Multiple injection tracking
- `test_injection_event_vs_log_parsing_comparison`: Old vs new comparison
## Future Enhancements
Possible improvements:
1. Filter events by injection name at server side (query parameter)
2. Include injection parameters in events
3. Add event timestamps
4. Support for event history/replay
5. WebSocket support (if bidirectional communication needed)

10
init.cc
View File

@@ -11,6 +11,7 @@
#include "seastarx.hh"
#include "db/config.hh"
#include <boost/algorithm/string/trim.hpp>
#include <seastar/core/coroutine.hh>
#include "sstables/sstable_compressor_factory.hh"
#include "gms/feature_service.hh"
@@ -29,7 +30,11 @@ std::set<gms::inet_address> get_seeds_from_db_config(const db::config& cfg,
std::set<gms::inet_address> seeds;
if (seed_provider.parameters.contains("seeds")) {
for (const auto& seed : utils::split_comma_separated_list(seed_provider.parameters.at("seeds"))) {
size_t begin = 0;
size_t next = 0;
sstring seeds_str = seed_provider.parameters.find("seeds")->second;
while (begin < seeds_str.length() && begin != (next=seeds_str.find(",",begin))) {
auto seed = boost::trim_copy(seeds_str.substr(begin,next-begin));
try {
seeds.emplace(gms::inet_address::lookup(seed, family, preferred).get());
} catch (...) {
@@ -41,10 +46,11 @@ std::set<gms::inet_address> get_seeds_from_db_config(const db::config& cfg,
seed,
std::current_exception());
}
begin = next+1;
}
}
if (seeds.empty()) {
seeds.emplace("127.0.0.1");
seeds.emplace(gms::inet_address("127.0.0.1"));
}
startlog.info("seeds={{{}}}, listen_address={}, broadcast_address={}",
fmt::join(seeds, ", "), listen, broadcast_address);

View File

@@ -157,10 +157,7 @@ fedora_packages=(
podman
buildah
# for cassandra-stress
java-openjdk-headless
snappy
https://github.com/scylladb/cassandra-stress/releases/download/v3.18.1/cassandra-stress-java21-3.18.1-1.noarch.rpm
elfutils
jq
@@ -389,10 +386,6 @@ elif [ "$ID" = "fedora" ]; then
exit 1
fi
dnf install -y "${fedora_packages[@]}" "${fedora_python3_packages[@]}"
# Fedora 45 tightened key checks, and cassandra-stress is not signed yet.
dnf install --no-gpgchecks -y https://github.com/scylladb/cassandra-stress/releases/download/v3.18.1/cassandra-stress-java21-3.18.1-1.noarch.rpm
PIP_DEFAULT_ARGS="--only-binary=:all: -v"
pip_constrained_packages=""
for package in "${!pip_packages[@]}"

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9034610470ff645fab03da5ad6c690e5b41f3307ea4b529c7e63b0786a1289ed
size 6539600
oid sha256:cb48c6afc5bf2a62234e069c8dfc6ae491645f7fb200072bb73dac148349c472
size 6543556

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0c4bbf51dbe01d684ea5b9a9157781988ed499604d2fde90143bad0b9a5594f0
size 6543944
oid sha256:2433f7a1fc5cda0dd990ab59587eb6046dca0fe1ae48d599953d1936fe014ed9
size 6492176

Submodule seastar updated: d2953d2ad1...f55dc7ebed

View File

@@ -1441,6 +1441,43 @@ class topology_coordinator : public endpoint_lifecycle_subscriber
}
}
// Logs up to max_count currently active tablet transitions at INFO level,
// then emits a single "... and N more" summary line if additional
// transitions exist beyond the limit.
//
// Only the replicas actually involved in each transition (leaving/pending)
// are logged, rather than all replicas, to focus on what is changing.
void log_active_transitions(size_t max_count) {
    auto tm = get_token_metadata_ptr();
    size_t logged_count = 0;
    size_t total_count = 0;
    for (auto&& [base_table, tables [[maybe_unused]]] : tm->tablets().all_table_groups()) {
        // Fetch the tablet map once per table group.
        const auto& tmap = tm->tablets().get_tablet_map(base_table);
        // Count the whole group up front so the "... and N more" summary is
        // accurate even when the logging limit is reached mid-group.
        // (Breaking out of the inner loop before counting the group's
        // remaining transitions would undercount the total.)
        total_count += tmap.transitions().size();
        if (logged_count >= max_count) {
            continue;
        }
        for (auto&& [tablet, trinfo] : tmap.transitions()) {
            if (logged_count >= max_count) {
                break;
            }
            locator::global_tablet_id gid { base_table, tablet };
            const auto& tinfo = tmap.get_tablet_info(tablet);
            auto leaving = locator::get_leaving_replica(tinfo, trinfo);
            auto pending = trinfo.pending_replica;
            // leaving/pending are optional; append each only when present.
            rtlogger.info("Active {} transition: tablet={}, stage={}{}{}",
                trinfo.transition, gid, trinfo.stage,
                leaving ? fmt::format(", leaving={}", *leaving) : "",
                pending ? fmt::format(", pending={}", *pending) : "");
            logged_count++;
        }
    }
    if (total_count > max_count) {
        rtlogger.info("... and {} more active transitions", total_count - max_count);
    }
}
// When "drain" is true, we migrate tablets only as long as there are nodes to drain
// and then change the transition state to write_both_read_old. Also, while draining,
// we ignore pending topology requests which normally interrupt load balancing.
@@ -2026,6 +2063,7 @@ class topology_coordinator : public endpoint_lifecycle_subscriber
// to check atomically with event.wait()
if (!_tablets_ready) {
rtlogger.debug("Going to sleep with active tablet transitions");
log_active_transitions(5);
release_guard(std::move(guard));
co_await await_event();
}

View File

@@ -436,10 +436,7 @@ tablet_stream_files(netw::messaging_service& ms, std::list<stream_blob_info> sou
stream_options.buffer_size = file_stream_buffer_size;
stream_options.read_ahead = file_stream_read_ahead;
for (auto&& source_info : sources) {
// Keep stream_blob_info alive only at duration of streaming. Allowing the file descriptor
// of the sstable component to be released right after it has been streamed.
auto info = std::exchange(source_info, {});
for (auto& info : sources) {
auto& filename = info.filename;
std::optional<input_stream<char>> fstream;
bool fstream_closed = false;
@@ -620,7 +617,6 @@ tablet_stream_files(netw::messaging_service& ms, std::list<stream_blob_info> sou
ops_id, filename, targets, total_size, get_bw(total_size, start_time));
}
}
co_await utils::get_local_injector().inject("tablet_stream_files_end_wait", utils::wait_for_message(std::chrono::seconds(60)));
if (error) {
blogger.warn("fstream[{}] Master failed sending files_nr={} files={} targets={} send_size={} bw={} error={}",
ops_id, sources.size(), sources, targets, ops_total_size, get_bw(ops_total_size, ops_start_time), error);
@@ -684,20 +680,15 @@ future<stream_files_response> tablet_stream_files_handler(replica::database& db,
if (files.empty()) {
co_return resp;
}
auto sstable_nr = sstables.size();
// Release reference to sstables to be streamed here. Since one sstable is streamed at a time,
// a sstable - that has been compacted - can have its space released from disk right after
// that sstable's content has been fully streamed.
sstables.clear();
blogger.debug("stream_sstables[{}] Started sending sstable_nr={} files_nr={} files={} range={}",
req.ops_id, sstable_nr, files.size(), files, req.range);
req.ops_id, sstables.size(), files.size(), files, req.range);
auto ops_start_time = std::chrono::steady_clock::now();
auto files_nr = files.size();
size_t stream_bytes = co_await tablet_stream_files(ms, std::move(files), req.targets, req.table, req.ops_id, req.topo_guard);
resp.stream_bytes = stream_bytes;
auto duration = std::chrono::steady_clock::now() - ops_start_time;
blogger.info("stream_sstables[{}] Finished sending sstable_nr={} files_nr={} range={} stream_bytes={} stream_time={} stream_bw={}",
req.ops_id, sstable_nr, files_nr, req.range, stream_bytes, duration, get_bw(stream_bytes, ops_start_time));
req.ops_id, sstables.size(), files_nr, req.range, stream_bytes, duration, get_bw(stream_bytes, ops_start_time));
co_return resp;
}

View File

@@ -415,7 +415,7 @@ future<utils::chunked_vector<task_identity>> task_manager::virtual_task::impl::g
auto nodes = module->get_nodes();
co_await utils::get_local_injector().inject("tasks_vt_get_children", [] (auto& handler) -> future<> {
tmlogger.info("tasks_vt_get_children: waiting");
co_await handler.wait_for_message(std::chrono::steady_clock::now() + std::chrono::seconds{60});
co_await handler.wait_for_message(std::chrono::steady_clock::now() + std::chrono::seconds{10});
});
co_return co_await map_reduce(nodes, [ms, parent_id, is_host_alive = std::move(is_host_alive)] (auto host_id) -> future<utils::chunked_vector<task_identity>> {
if (is_host_alive(host_id)) {

View File

@@ -51,17 +51,17 @@ BOOST_AUTO_TEST_CASE(TestXmlErrorPayload) {
auto error = aws::aws_error::parse(build_xml_response("IncompleteSignatureException", message, requestId)).value();
BOOST_REQUIRE_EQUAL(aws::aws_error_type::INCOMPLETE_SIGNATURE, error.get_error_type());
BOOST_REQUIRE_EQUAL(message, error.get_error_message());
BOOST_REQUIRE_EQUAL(error.is_retryable(), utils::http::retryable::no);
BOOST_REQUIRE_EQUAL(error.is_retryable(), aws::retryable::no);
error = aws::aws_error::parse(build_xml_response("InternalFailure", message, requestId, message_style::plural)).value();
BOOST_REQUIRE_EQUAL(aws::aws_error_type::INTERNAL_FAILURE, error.get_error_type());
BOOST_REQUIRE_EQUAL(message, error.get_error_message());
BOOST_REQUIRE_EQUAL(error.is_retryable(), utils::http::retryable::yes);
BOOST_REQUIRE_EQUAL(error.is_retryable(), aws::retryable::yes);
error = aws::aws_error::parse(build_xml_response("IDontExist", message, requestId, message_style::plural)).value();
BOOST_REQUIRE_EQUAL(aws::aws_error_type::UNKNOWN, error.get_error_type());
BOOST_REQUIRE_EQUAL(message, error.get_error_message());
BOOST_REQUIRE_EQUAL(error.is_retryable(), utils::http::retryable::no);
BOOST_REQUIRE_EQUAL(error.is_retryable(), aws::retryable::no);
auto no_error = aws::aws_error::parse("");
BOOST_REQUIRE_EQUAL(no_error.has_value(), false);
@@ -75,7 +75,7 @@ BOOST_AUTO_TEST_CASE(TestXmlErrorPayload) {
error = aws::aws_error::parse(response).value();
BOOST_REQUIRE_EQUAL(aws::aws_error_type::INTERNAL_FAILURE, error.get_error_type());
BOOST_REQUIRE_EQUAL(message, error.get_error_message());
BOOST_REQUIRE_EQUAL(error.is_retryable(), utils::http::retryable::yes);
BOOST_REQUIRE_EQUAL(error.is_retryable(), aws::retryable::yes);
}
BOOST_AUTO_TEST_CASE(TestErrorsWithPrefixParse) {
@@ -92,7 +92,7 @@ BOOST_AUTO_TEST_CASE(TestErrorsWithPrefixParse) {
auto error = aws::aws_error::parse(build_xml_response(exceptionPrefix + "IDon'tExist", "JunkMessage", requestId)).value();
BOOST_REQUIRE_EQUAL(aws::aws_error_type::UNKNOWN, error.get_error_type());
BOOST_REQUIRE_EQUAL("JunkMessage", error.get_error_message());
BOOST_REQUIRE_EQUAL(error.is_retryable(), utils::http::retryable::no);
BOOST_REQUIRE_EQUAL(error.is_retryable(), aws::retryable::no);
}
BOOST_AUTO_TEST_CASE(TestErrorsWithoutPrefixParse) {
@@ -107,15 +107,7 @@ BOOST_AUTO_TEST_CASE(TestErrorsWithoutPrefixParse) {
auto error = aws::aws_error::parse(build_xml_response("IDon'tExist", "JunkMessage", requestId)).value();
BOOST_REQUIRE_EQUAL(aws::aws_error_type::UNKNOWN, error.get_error_type());
BOOST_REQUIRE_EQUAL("JunkMessage", error.get_error_message());
BOOST_REQUIRE_EQUAL(error.is_retryable(), utils::http::retryable::no);
}
BOOST_AUTO_TEST_CASE(TestHelperFunctions) {
BOOST_REQUIRE_EQUAL(utils::http::from_http_code(seastar::http::reply::status_type::service_unavailable), utils::http::retryable::yes);
BOOST_REQUIRE_EQUAL(utils::http::from_http_code(seastar::http::reply::status_type::unauthorized), utils::http::retryable::no);
BOOST_REQUIRE_EQUAL(utils::http::from_system_error(std::system_error(ECONNRESET, std::system_category())), utils::http::retryable::yes);
BOOST_REQUIRE_EQUAL(utils::http::from_system_error(std::system_error(EADDRINUSE, std::system_category())), utils::http::retryable::no);
BOOST_REQUIRE_EQUAL(error.is_retryable(), aws::retryable::no);
}
BOOST_AUTO_TEST_CASE(TestNestedException) {
@@ -134,7 +126,7 @@ BOOST_AUTO_TEST_CASE(TestNestedException) {
auto error = aws::aws_error::from_exception_ptr(std::current_exception());
BOOST_REQUIRE_EQUAL(aws::aws_error_type::NETWORK_CONNECTION, error.get_error_type());
BOOST_REQUIRE_EQUAL("Software caused connection abort", error.get_error_message());
BOOST_REQUIRE_EQUAL(error.is_retryable(), utils::http::retryable::yes);
BOOST_REQUIRE_EQUAL(error.is_retryable(), aws::retryable::yes);
}
// Test nested exceptions where the innermost is NOT a system_error
@@ -148,7 +140,7 @@ BOOST_AUTO_TEST_CASE(TestNestedException) {
auto error = aws::aws_error::from_exception_ptr(std::current_exception());
BOOST_REQUIRE_EQUAL(aws::aws_error_type::UNKNOWN, error.get_error_type());
BOOST_REQUIRE_EQUAL("Higher level runtime_error", error.get_error_message());
BOOST_REQUIRE_EQUAL(error.is_retryable(), utils::http::retryable::no);
BOOST_REQUIRE_EQUAL(error.is_retryable(), aws::retryable::no);
}
// Test single exception which is NOT a nested exception
@@ -158,7 +150,7 @@ BOOST_AUTO_TEST_CASE(TestNestedException) {
auto error = aws::aws_error::from_exception_ptr(std::current_exception());
BOOST_REQUIRE_EQUAL(aws::aws_error_type::UNKNOWN, error.get_error_type());
BOOST_REQUIRE_EQUAL("Something bad happened", error.get_error_message());
BOOST_REQUIRE_EQUAL(error.is_retryable(), utils::http::retryable::no);
BOOST_REQUIRE_EQUAL(error.is_retryable(), aws::retryable::no);
}
// Test with non-std::exception
@@ -168,7 +160,7 @@ BOOST_AUTO_TEST_CASE(TestNestedException) {
auto error = aws::aws_error::from_exception_ptr(std::current_exception());
BOOST_REQUIRE_EQUAL(aws::aws_error_type::UNKNOWN, error.get_error_type());
BOOST_REQUIRE_EQUAL("No error message was provided, exception content: char const*", error.get_error_message());
BOOST_REQUIRE_EQUAL(error.is_retryable(), utils::http::retryable::no);
BOOST_REQUIRE_EQUAL(error.is_retryable(), aws::retryable::no);
}
// Test system_error
@@ -178,7 +170,7 @@ BOOST_AUTO_TEST_CASE(TestNestedException) {
auto error = aws::aws_error::from_exception_ptr(std::current_exception());
BOOST_REQUIRE_EQUAL(aws::aws_error_type::NETWORK_CONNECTION, error.get_error_type());
BOOST_REQUIRE_EQUAL("Software caused connection abort", error.get_error_message());
BOOST_REQUIRE_EQUAL(error.is_retryable(), utils::http::retryable::yes);
BOOST_REQUIRE_EQUAL(error.is_retryable(), aws::retryable::yes);
}
// Test aws_exception
@@ -188,7 +180,7 @@ BOOST_AUTO_TEST_CASE(TestNestedException) {
auto error = aws::aws_error::from_exception_ptr(std::current_exception());
BOOST_REQUIRE_EQUAL(aws::aws_error_type::HTTP_TOO_MANY_REQUESTS, error.get_error_type());
BOOST_REQUIRE_EQUAL("", error.get_error_message());
BOOST_REQUIRE_EQUAL(error.is_retryable(), utils::http::retryable::yes);
BOOST_REQUIRE_EQUAL(error.is_retryable(), aws::retryable::yes);
}
// Test httpd::unexpected_status_error
@@ -198,6 +190,6 @@ BOOST_AUTO_TEST_CASE(TestNestedException) {
auto error = aws::aws_error::from_exception_ptr(std::current_exception());
BOOST_REQUIRE_EQUAL(aws::aws_error_type::HTTP_NETWORK_CONNECT_TIMEOUT, error.get_error_type());
BOOST_REQUIRE_EQUAL(" HTTP code: 599 Network Connect Timeout", error.get_error_message());
BOOST_REQUIRE_EQUAL(error.is_retryable(), utils::http::retryable::yes);
BOOST_REQUIRE_EQUAL(error.is_retryable(), aws::retryable::yes);
}
}

View File

@@ -29,7 +29,6 @@
#include "types/list.hh"
#include "types/set.hh"
#include "schema/schema_builder.hh"
#include "cql3/functions/vector_similarity_fcts.hh"
BOOST_AUTO_TEST_SUITE(cql_functions_test)
@@ -423,96 +422,4 @@ SEASTAR_TEST_CASE(test_aggregate_functions_vector_type) {
});
}
SEASTAR_THREAD_TEST_CASE(test_extract_float_vector) {
// Compare standard deserialization path vs optimized extraction path
auto serialize = [](size_t dim, const std::vector<float>& values) {
auto vector_type = vector_type_impl::get_instance(float_type, dim);
std::vector<data_value> data_vals;
data_vals.reserve(values.size());
for (float f : values) {
data_vals.push_back(data_value(f));
}
return vector_type->decompose(make_list_value(vector_type, data_vals));
};
auto deserialize_standard = [](size_t dim, const bytes_opt& serialized) {
auto vector_type = vector_type_impl::get_instance(float_type, dim);
data_value v = vector_type->deserialize(*serialized);
const auto& elements = value_cast<std::vector<data_value>>(v);
std::vector<float> result;
result.reserve(elements.size());
for (const auto& elem : elements) {
result.push_back(value_cast<float>(elem));
}
return result;
};
auto compare_vectors = [](const std::vector<float>& a, const std::vector<float>& b) {
BOOST_REQUIRE_EQUAL(a.size(), b.size());
for (size_t i = 0; i < a.size(); ++i) {
if (std::isnan(a[i]) && std::isnan(b[i])) {
continue; // Both NaN, consider equal
}
BOOST_REQUIRE_EQUAL(a[i], b[i]);
}
};
// Prepare test cases
std::vector<std::vector<float>> test_vectors = {
// Small vectors with explicit values
{1.0f, 2.5f},
{-1.5f, 0.0f, 3.14159f},
// Special floating-point values
{
std::numeric_limits<float>::infinity(),
-std::numeric_limits<float>::infinity(),
0.0f,
-0.0f,
std::numeric_limits<float>::min(),
std::numeric_limits<float>::max()
},
// NaN values (require special comparison)
{
std::numeric_limits<float>::quiet_NaN(),
1.0f,
std::numeric_limits<float>::signaling_NaN()
}
};
// Add common embedding dimensions with pattern-generated data
for (size_t dim : {128, 384, 768, 1024, 1536}) {
std::vector<float> vec(dim);
for (size_t i = 0; i < dim; ++i) {
vec[i] = static_cast<float>(i % 100) * 0.01f;
}
test_vectors.push_back(std::move(vec));
}
// Run tests for all test vectors
for (const auto& vec : test_vectors) {
size_t dim = vec.size();
auto serialized = serialize(dim, vec);
auto standard = deserialize_standard(dim, serialized);
compare_vectors(standard, cql3::functions::detail::extract_float_vector(serialized, dim));
}
// Null parameter should throw
BOOST_REQUIRE_EXCEPTION(
cql3::functions::detail::extract_float_vector(std::nullopt, 3),
exceptions::invalid_request_exception,
seastar::testing::exception_predicate::message_contains("Cannot extract float vector from null parameter")
);
// Size mismatch should throw
for (auto [actual_dim, expected_dim] : {std::pair{2, 3}, {4, 3}}) {
std::vector<float> vec(actual_dim, 1.0f);
auto serialized = serialize(actual_dim, vec);
BOOST_REQUIRE_EXCEPTION(
cql3::functions::detail::extract_float_vector(serialized, expected_dim),
exceptions::invalid_request_exception,
seastar::testing::exception_predicate::message_contains("Invalid vector size")
);
}
}
BOOST_AUTO_TEST_SUITE_END()

View File

@@ -113,23 +113,15 @@ static future<> compare_object_data(const local_gcs_wrapper& env, std::string_vi
BOOST_REQUIRE_EQUAL(read, total);
}
using namespace std::string_literals;
static constexpr auto prefix = "bork/ninja/"s;
// #28398 include a prefix in all names.
static std::string make_name() {
return fmt::format("{}{}", prefix, utils::UUID_gen::get_time_UUID());
}
static future<> test_read_write_helper(const local_gcs_wrapper& env, size_t dest_size, std::optional<size_t> specific_buffer_size = std::nullopt) {
auto& c = env.client();
auto name = make_name();
auto uuid = fmt::format("{}", utils::UUID_gen::get_time_UUID());
std::vector<temporary_buffer<char>> written;
// ensure we remove the object
env.objects_to_delete.emplace_back(name);
co_await create_object_of_size(c, env.bucket, name, dest_size, &written, specific_buffer_size);
co_await compare_object_data(env, name, std::move(written));
env.objects_to_delete.emplace_back(uuid);
co_await create_object_of_size(c, env.bucket, uuid, dest_size, &written, specific_buffer_size);
co_await compare_object_data(env, uuid, std::move(written));
}
BOOST_AUTO_TEST_SUITE(gcs_tests, *seastar::testing::async_fixture<gcs_fixture>())
@@ -155,28 +147,21 @@ SEASTAR_FIXTURE_TEST_CASE(test_gcp_storage_list_objects, local_gcs_wrapper, *che
auto& c = env.client();
std::unordered_map<std::string, uint64_t> names;
for (size_t i = 0; i < 10; ++i) {
auto name = make_name();
auto name = fmt::format("{}", utils::UUID_gen::get_time_UUID());
auto size = tests::random::get_int(size_t(1), size_t(2*1024*1024));
env.objects_to_delete.emplace_back(name);
co_await create_object_of_size(c, env.bucket, name, size);
names.emplace(name, size);
}
utils::gcp::storage::bucket_paging paging;
auto infos = co_await c.list_objects(env.bucket);
size_t n_found = 0;
for (;;) {
auto infos = co_await c.list_objects(env.bucket, "", paging);
for (auto& info : infos) {
auto i = names.find(info.name);
if (i != names.end()) {
BOOST_REQUIRE_EQUAL(info.size, i->second);
++n_found;
}
}
if (infos.empty()) {
break;
for (auto& info : infos) {
auto i = names.find(info.name);
if (i != names.end()) {
BOOST_REQUIRE_EQUAL(info.size, i->second);
++n_found;
}
}
BOOST_REQUIRE_EQUAL(n_found, names.size());
@@ -185,7 +170,7 @@ SEASTAR_FIXTURE_TEST_CASE(test_gcp_storage_list_objects, local_gcs_wrapper, *che
SEASTAR_FIXTURE_TEST_CASE(test_gcp_storage_delete_object, local_gcs_wrapper, *check_gcp_storage_test_enabled()) {
auto& env = *this;
auto& c = env.client();
auto name = make_name();
auto name = fmt::format("{}", utils::UUID_gen::get_time_UUID());
env.objects_to_delete.emplace_back(name);
co_await create_object_of_size(c, env.bucket, name, 128);
{
@@ -205,7 +190,7 @@ SEASTAR_FIXTURE_TEST_CASE(test_gcp_storage_delete_object, local_gcs_wrapper, *ch
SEASTAR_FIXTURE_TEST_CASE(test_gcp_storage_skip_read, local_gcs_wrapper, *check_gcp_storage_test_enabled()) {
auto& env = *this;
auto& c = env.client();
auto name = make_name();
auto name = fmt::format("{}", utils::UUID_gen::get_time_UUID());
std::vector<temporary_buffer<char>> bufs;
constexpr size_t file_size = 12*1024*1024 + 384*7 + 31;
@@ -258,7 +243,7 @@ SEASTAR_FIXTURE_TEST_CASE(test_merge_objects, local_gcs_wrapper, *check_gcp_stor
size_t total = 0;
for (size_t i = 0; i < 32; ++i) {
auto name = make_name();
auto name = fmt::format("{}", utils::UUID_gen::get_time_UUID());
auto size = tests::random::get_int(size_t(1), size_t(2*1024*1024));
env.objects_to_delete.emplace_back(name);
co_await create_object_of_size(c, env.bucket, name, size, &bufs);
@@ -266,7 +251,7 @@ SEASTAR_FIXTURE_TEST_CASE(test_merge_objects, local_gcs_wrapper, *check_gcp_stor
total += size;
}
auto name = make_name();
auto name = fmt::format("{}", utils::UUID_gen::get_time_UUID());
env.objects_to_delete.emplace_back(name);
auto info = co_await c.merge_objects(env.bucket, name, names);

View File

@@ -980,88 +980,3 @@ BOOST_AUTO_TEST_CASE(s3_fqn_manipulation) {
BOOST_REQUIRE_EQUAL(bucket_name, "bucket");
BOOST_REQUIRE_EQUAL(object_name, "prefix1/prefix2/foo.bar");
}
BOOST_AUTO_TEST_CASE(part_size_calculation_test) {
{
BOOST_REQUIRE_EXCEPTION(s3::calc_part_size(490_GiB, 5_MiB), std::runtime_error, [](const std::runtime_error& e) {
return std::string(e.what()).starts_with("too many parts: 100352 > 10000");
});
}
{
auto [parts, size] = s3::calc_part_size(490_GiB, 100_MiB);
BOOST_REQUIRE_EQUAL(size, 100_MiB);
BOOST_REQUIRE(parts == 5018);
}
{
BOOST_REQUIRE_EXCEPTION(s3::calc_part_size(490_GiB, 4_MiB), std::runtime_error, [](const std::runtime_error& e) {
return std::string(e.what()).starts_with("part_size too small: 4194304 is smaller than minimum part size: 5242880");
});
}
{
auto [parts, size] = s3::calc_part_size(50_MiB, 0);
BOOST_REQUIRE_EQUAL(size, 50_MiB);
BOOST_REQUIRE_EQUAL(parts, 1);
}
{
auto [parts, size] = s3::calc_part_size(49_MiB, 0);
BOOST_REQUIRE_EQUAL(size, 50_MiB);
BOOST_REQUIRE_EQUAL(parts, 1);
}
{
auto [parts, size] = s3::calc_part_size(490_GiB, 0);
BOOST_REQUIRE_EQUAL(size, 51_MiB);
BOOST_REQUIRE(parts == 9839);
}
{
auto [parts, size] = s3::calc_part_size(50_MiB * 10000, 0);
BOOST_REQUIRE_EQUAL(size, 50_MiB);
BOOST_REQUIRE_EQUAL(parts, 10000);
}
{
auto [parts, size] = s3::calc_part_size(50_MiB * 10000 + 1, 0);
BOOST_REQUIRE(size > 50_MiB);
BOOST_REQUIRE(parts <= 10000);
}
{
BOOST_REQUIRE_EXCEPTION(s3::calc_part_size(50_TiB, 0), std::runtime_error, [](const std::runtime_error& e) {
return std::string(e.what()).starts_with("object size too large: 54975581388800 is larger than maximum S3 object size: 53687091200000");
});
}
{
BOOST_REQUIRE_EXCEPTION(s3::calc_part_size(1_TiB, 5_GiB + 1), std::runtime_error, [](const std::runtime_error& e) {
return std::string(e.what()).starts_with("part_size too large: 5368709121 is larger than maximum part size: 5368709120");
});
}
{
auto [parts, size] = s3::calc_part_size(5_TiB, 0);
BOOST_REQUIRE_EQUAL(parts, 9987);
BOOST_REQUIRE_EQUAL(size, 525_MiB);
}
{
auto [parts, size] = s3::calc_part_size(5_MiB * 10000, 5_MiB);
BOOST_REQUIRE_EQUAL(size, 5_MiB);
BOOST_REQUIRE_EQUAL(parts, 10000);
}
{
size_t total = 5_MiB * 10001; // 10001 parts at 5 MiB
BOOST_REQUIRE_EXCEPTION(
s3::calc_part_size(total, 5_MiB), std::runtime_error, [](auto& e) { return std::string(e.what()).starts_with("too many parts: 10001 > 10000"); });
}
{
size_t total = 500_GiB + 123; // odd size to force non-MiB alignment
auto [parts, size] = s3::calc_part_size(total, 0);
BOOST_REQUIRE(size % 1_MiB == 0); // aligned
BOOST_REQUIRE(parts <= 10000);
}
{
auto [parts, size] = s3::calc_part_size(6_MiB, 0);
BOOST_REQUIRE_EQUAL(size, 50_MiB);
BOOST_REQUIRE_EQUAL(parts, 1);
}
{
auto [parts, size] = s3::calc_part_size(100_MiB, 200_MiB);
BOOST_REQUIRE_EQUAL(parts, 1);
BOOST_REQUIRE_EQUAL(size, 200_MiB);
}
}

View File

@@ -8,8 +8,6 @@ from cassandra.auth import PlainTextAuthProvider
from cassandra.cluster import Cluster, NoHostAvailable
from cassandra import Unauthorized
from cassandra.connection import UnixSocketEndPoint
from cassandra.policies import WhiteListRoundRobinPolicy
from test.cluster.conftest import cluster_con
from test.pylib.manager_client import ManagerClient
@@ -60,7 +58,7 @@ async def test_maintenance_socket(manager: ManagerClient):
else:
pytest.fail("User 'john' has no permissions to access ks2.t1")
maintenance_cluster = cluster_con([UnixSocketEndPoint(socket)], load_balancing_policy=WhiteListRoundRobinPolicy([UnixSocketEndPoint(socket)]))
maintenance_cluster = cluster_con([UnixSocketEndPoint(socket)])
maintenance_session = maintenance_cluster.connect()
# check that the maintenance session has superuser permissions

View File

@@ -262,17 +262,14 @@ async def manager(request: pytest.FixtureRequest,
# Check if the test has the check_nodes_for_errors marker
found_errors = await manager_client.check_all_errors(check_all_errors=(request.node.get_closest_marker("check_nodes_for_errors") is not None))
failed = failed or found_errors
failed_test_dir_path = None
if failed or found_errors:
if failed:
# Save scylladb logs for failed tests in a separate directory and copy XML report to the same directory to have
# all related logs in one dir.
# Then add property to the XML report with the path to the directory, so it can be visible in Jenkins
failed_test_dir_path = testpy_test.suite.log_dir / "failed_test" / test_case_name.translate(
str.maketrans('[]', '()'))
failed_test_dir_path = testpy_test.suite.log_dir / "failed_test" / test_case_name.translate(str.maketrans('[]', '()'))
failed_test_dir_path.mkdir(parents=True, exist_ok=True)
if failed:
await manager_client.gather_related_logs(
failed_test_dir_path,
{'pytest.log': test_log, 'test_py.log': test_py_log_test}
@@ -288,9 +285,7 @@ async def manager(request: pytest.FixtureRequest,
cluster_status = await manager_client.after_test(test_case_name, not failed)
await manager_client.stop() # Stop client session and close driver after each test
if cluster_status["server_broken"] and not failed:
failed = True
if cluster_status["server_broken"]:
pytest.fail(
f"test case {test_case_name} left unfinished tasks on Scylla server. Server marked as broken,"
f" server_broken_reason: {cluster_status["message"]}"
@@ -329,8 +324,7 @@ async def manager(request: pytest.FixtureRequest,
with open(failed_test_dir_path / "found_errors.txt", "w") as f:
f.write("\n".join(full_message))
if not failed:
pytest.fail(f"\n{'\n'.join(full_message)}")
pytest.fail(f"\n{'\n'.join(full_message)}")
# "cql" fixture: set up client object for communicating with the CQL API.
# Since connection is managed by manager just return that object

View File

@@ -36,6 +36,7 @@ run_in_release:
run_in_dev:
- test_raft_ignore_nodes
- test_group0_schema_versioning
- test_different_group0_ids
- test_zero_token_nodes_no_replication
- test_not_enough_token_owners
- test_replace_alive_node

View File

@@ -6,30 +6,53 @@
from test.pylib.manager_client import ManagerClient
import asyncio
import pytest
from test.pylib.util import wait_for_first_completed
@pytest.mark.asyncio
@pytest.mark.xfail(reason="gossiper topology mode is no longer supported, need to rewrite the test using raft topology")
async def test_different_group0_ids(manager: ManagerClient):
"""
The test starts two single-node clusters (with different group0_ids). Node B (the
node from the second cluster) is restarted with seeds containing node A (the node
from the first cluster), and thus it tries to gossip node A. The test checks that
node A rejects gossip_digest_syn.
The reproducer for #14448.
Note: this test relies on the fact that the only node in a single-node cluster
always gossips with its seeds. This can be considered a bug, although a mild one.
If we ever fix it, this test can be rewritten by starting a two-node cluster and
recreating group0 on one of the nodes via the recovery procedure.
The test starts two nodes with different group0_ids. The second node
is restarted and tries to join the cluster consisting of the first node.
gossip_digest_syn message should be rejected by the first node, so
the second node will not be able to join the cluster.
This test uses repair-based node operations to make this test easier.
If the second node successfully joins the cluster, their tokens metadata
will be merged and the repair service will allow to decommission the second node.
If not - decommissioning the second node will fail with an exception
"zero replica after the removal" thrown by the repair service.
"""
scylla_a = await manager.server_add()
scylla_b = await manager.server_add(start=False)
# Consistent topology changes are disabled to use repair based node operations.
cfg = {'force_gossip_topology_changes': True, 'tablets_mode_for_new_keyspaces': 'disabled'}
scylla_a = await manager.server_add(config = cfg)
scylla_b = await manager.server_add(start=False, config = cfg)
await manager.server_start(scylla_b.server_id, seeds=[scylla_b.ip_addr])
id_b = await manager.get_host_id(scylla_b.server_id)
await manager.server_stop(scylla_b.server_id)
await manager.server_start(scylla_b.server_id, seeds=[scylla_a.ip_addr])
await manager.server_start(scylla_b.server_id, seeds=[scylla_a.ip_addr, scylla_b.ip_addr])
log_file_a = await manager.server_open_log(scylla_a.server_id)
await log_file_a.wait_for(f'Group0Id mismatch from {id_b}', timeout=30)
log_file_b = await manager.server_open_log(scylla_b.server_id)
# Wait for a gossip round to finish
await wait_for_first_completed([
log_file_b.wait_for(f'InetAddress {scylla_a.ip_addr} is now UP'), # The second node joins the cluster
log_file_a.wait_for(f'Group0Id mismatch') # The first node discards gossip from the second node
])
# Check if decommissioning the second node fails.
# Repair service throws a runtime exception "zero replica after the removal"
# when it tries to remove the only one node from the cluster.
# If it is not thrown, it means that the second node successfully send a gossip
# to the first node and they merged their tokens metadata.
with pytest.raises(Exception, match='zero replica after the removal'):
await manager.decommission_node(scylla_b.server_id)

View File

@@ -1,140 +0,0 @@
#
# Copyright (C) 2025-present ScyllaDB
#
# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
#
"""
Test for error injection event stream functionality.
This test demonstrates the new SSE-based error injection event system
that eliminates the need for log parsing in tests.
"""
import asyncio
import logging
import pytest
from test.pylib.manager_client import ManagerClient
from test.pylib.rest_client import injection_event_stream
logger = logging.getLogger(__name__)
@pytest.mark.asyncio
@pytest.mark.skip_mode('release', 'error injections are not supported in release mode')
async def test_injection_event_stream_basic(manager: ManagerClient):
"""
Test basic error injection event stream functionality.
This test verifies that:
1. We can connect to the SSE event stream
2. Events are received when injections are triggered
3. We can wait for specific injections without log parsing
"""
servers = await manager.servers_add(1)
server_ip = servers[0].ip_addr
# Connect to the injection event stream
async with injection_event_stream(server_ip) as event_stream:
logger.info("Connected to injection event stream")
# Enable a simple injection
test_injection_name = "test_injection_event_basic"
await manager.api.enable_injection(server_ip, test_injection_name, one_shot=True)
# Trigger the injection by calling message_injection
# In real tests, the injection would be triggered by actual code execution
await manager.api.message_injection(server_ip, test_injection_name)
# Wait for the injection event (no log parsing needed!)
try:
event = await event_stream.wait_for_injection(test_injection_name, timeout=10.0)
logger.info(f"Received injection event: {event}")
# Verify event structure
assert event['injection'] == test_injection_name
assert 'type' in event
assert 'shard' in event
logger.info(f"✓ Injection triggered on shard {event['shard']} with type {event['type']}")
except asyncio.TimeoutError:
pytest.fail(f"Injection event for '{test_injection_name}' not received within timeout")
@pytest.mark.asyncio
@pytest.mark.skip_mode('release', 'error injections are not supported in release mode')
async def test_injection_event_stream_multiple_injections(manager: ManagerClient):
"""
Test that we can track multiple injections via the event stream.
"""
servers = await manager.servers_add(1)
server_ip = servers[0].ip_addr
async with injection_event_stream(server_ip) as event_stream:
logger.info("Connected to injection event stream")
# Enable multiple injections
injection_names = [
"test_injection_1",
"test_injection_2",
"test_injection_3",
]
for name in injection_names:
await manager.api.enable_injection(server_ip, name, one_shot=False)
# Trigger injections in sequence
for name in injection_names:
await manager.api.message_injection(server_ip, name)
# Wait for each injection event
event = await event_stream.wait_for_injection(name, timeout=10.0)
logger.info(f"✓ Received event for {name}: type={event['type']}, shard={event['shard']}")
# Cleanup
for name in injection_names:
await manager.api.disable_injection(server_ip, name)
logger.info("✓ All injection events received successfully")
@pytest.mark.asyncio
@pytest.mark.skip_mode('release', 'error injections are not supported in release mode')
async def test_injection_event_vs_log_parsing_comparison(manager: ManagerClient):
"""
Demonstration test comparing the old log parsing approach vs new event stream approach.
This shows how the new SSE event stream eliminates the need for log parsing,
making tests faster and more reliable.
"""
servers = await manager.servers_add(1)
server = servers[0]
injection_name = "test_comparison_injection"
# OLD APPROACH: Log parsing (commented to show the pattern)
# -----------------------------------------------------
# log = await manager.server_open_log(server.server_id)
# mark = await log.mark()
# await manager.api.enable_injection(server.ip_addr, injection_name, one_shot=True)
# # ... trigger some operation that hits the injection ...
# mark, _ = await log.wait_for(f'{injection_name}: waiting', from_mark=mark)
# # Now we know the injection was hit by parsing logs
# -----------------------------------------------------
# NEW APPROACH: Event stream (no log parsing!)
# -----------------------------------------------------
async with injection_event_stream(server.ip_addr) as event_stream:
logger.info("✓ Connected to injection event stream (no log parsing needed)")
# Enable and trigger injection
await manager.api.enable_injection(server.ip_addr, injection_name, one_shot=True)
await manager.api.message_injection(server.ip_addr, injection_name)
# Wait for injection event - fast and reliable!
event = await event_stream.wait_for_injection(injection_name, timeout=10.0)
logger.info(f"✓ Injection detected via event stream: {event}")
# No log parsing, no regex matching, no waiting for log flushes
# Just direct event notification from the injection point
# -----------------------------------------------------
logger.info("✓ New event stream approach is faster and more reliable than log parsing!")

View File

@@ -7,14 +7,15 @@ import asyncio
import pytest
import time
import logging
import requests
import re
from cassandra.cluster import NoHostAvailable # type: ignore
from cassandra.cluster import ConnectionException, NoHostAvailable # type: ignore
from cassandra.query import SimpleStatement, ConsistencyLevel
from test.pylib.internal_types import IPAddress
from test.pylib.internal_types import ServerInfo
from test.pylib.manager_client import ManagerClient
from test.pylib.rest_client import ScyllaMetricsClient, TCPRESTClient, inject_error
from test.pylib.rest_client import inject_error
from test.pylib.tablets import get_tablet_replicas
from test.pylib.scylla_cluster import ReplaceConfig
from test.pylib.util import wait_for
@@ -24,21 +25,26 @@ from test.cluster.util import get_topology_coordinator, find_server_by_host_id,
logger = logging.getLogger(__name__)
async def get_hint_metrics(client: ScyllaMetricsClient, server_ip: IPAddress, metric_name: str):
metrics = await client.query(server_ip)
return metrics.get(f"scylla_hints_manager_{metric_name}")
def get_hint_manager_metric(server: ServerInfo, metric_name: str) -> int:
result = 0
metrics = requests.get(f"http://{server.ip_addr}:9180/metrics").text
pattern = re.compile(f"^scylla_hints_manager_{metric_name}")
for metric in metrics.split('\n'):
if pattern.match(metric) is not None:
result += int(float(metric.split()[1]))
return result
async def create_sync_point(client: TCPRESTClient, server_ip: IPAddress) -> str:
response = await client.post_json("/hinted_handoff/sync_point", host=server_ip, port=10_000)
return response
# Creates a sync point for ALL hosts.
def create_sync_point(node: ServerInfo) -> str:
return requests.post(f"http://{node.ip_addr}:10000/hinted_handoff/sync_point/").json()
async def await_sync_point(client: TCPRESTClient, server_ip: IPAddress, sync_point: str, timeout: int) -> bool:
def await_sync_point(node: ServerInfo, sync_point: str, timeout: int) -> bool:
params = {
"id": sync_point,
"timeout": str(timeout)
}
response = await client.get_json("/hinted_handoff/sync_point", host=server_ip, port=10_000, params=params)
response = requests.get(f"http://{node.ip_addr}:10000/hinted_handoff/sync_point", params=params).json()
match response:
case "IN_PROGRESS":
return False
@@ -60,7 +66,10 @@ async def test_write_cl_any_to_dead_node_generates_hints(manager: ManagerClient)
await manager.server_stop_gracefully(servers[1].server_id)
hints_before = await get_hint_metrics(manager.metrics, servers[0].ip_addr, "written")
def get_hints_written_count(server):
return get_hint_manager_metric(server, "written")
hints_before = get_hints_written_count(servers[0])
# Some of the inserts will be targeted to the dead node.
# The coordinator doesn't have live targets to send the write to, but it should write a hint.
@@ -68,7 +77,7 @@ async def test_write_cl_any_to_dead_node_generates_hints(manager: ManagerClient)
await cql.run_async(SimpleStatement(f"INSERT INTO {table} (pk, v) VALUES ({i}, {i+1})", consistency_level=ConsistencyLevel.ANY))
# Verify hints are written
hints_after = await get_hint_metrics(manager.metrics, servers[0].ip_addr, "written")
hints_after = get_hints_written_count(servers[0])
assert hints_after > hints_before
# For dropping the keyspace
@@ -134,29 +143,24 @@ async def test_sync_point(manager: ManagerClient):
# Mutations need to be applied to hinted handoff's commitlog before we create the sync point.
# Otherwise, the sync point will correspond to no hints at all.
async def check_written_hints(min_count: int) -> bool:
errors = await get_hint_metrics(manager.metrics, node1.ip_addr, "errors")
assert errors == 0, "Writing hints to disk failed"
hints = await get_hint_metrics(manager.metrics, node1.ip_addr, "written")
if hints >= min_count:
return True
return None
# We need to wrap the function in an async function to make `wait_for` be able to use it below.
async def check_no_hints_in_progress_node1() -> bool:
return get_hint_manager_metric(node1, "size_of_hints_in_progress") == 0
deadline = time.time() + 30
await wait_for(lambda: check_written_hints(2 * mutation_count), deadline)
await wait_for(check_no_hints_in_progress_node1, deadline)
sync_point1 = await create_sync_point(manager.api.client, node1.ip_addr)
sync_point1 = create_sync_point(node1)
await manager.server_start(node2.server_id)
await manager.server_sees_other_server(node1.ip_addr, node2.ip_addr)
assert not (await await_sync_point(manager.api.client, node1.ip_addr, sync_point1, 3))
assert not await_sync_point(node1, sync_point1, 30)
await manager.server_start(node3.server_id)
await manager.server_sees_other_server(node1.ip_addr, node3.ip_addr)
assert await await_sync_point(manager.api.client, node1.ip_addr, sync_point1, 30)
assert await_sync_point(node1, sync_point1, 30)
@pytest.mark.asyncio
@@ -202,8 +206,7 @@ async def test_hints_consistency_during_decommission(manager: ManagerClient):
await manager.servers_see_each_other([server1, server2, server3])
# Record the current position of hints so that we can wait for them later
sync_points = await asyncio.gather(*[create_sync_point(manager.api.client, srv.ip_addr) for srv in (server1, server2)])
sync_points = list(sync_points)
sync_points = [create_sync_point(srv) for srv in (server1, server2)]
async with asyncio.TaskGroup() as tg:
coord = await get_topology_coordinator(manager)
@@ -229,8 +232,7 @@ async def test_hints_consistency_during_decommission(manager: ManagerClient):
await manager.api.disable_injection(srv.ip_addr, "hinted_handoff_pause_hint_replay")
logger.info("Wait until hints are replayed from nodes 1 and 2")
await asyncio.gather(*(await_sync_point(manager.api.client, srv.ip_addr, pt, timeout=30)
for srv, pt in zip((server1, server2), sync_points)))
await asyncio.gather(*(asyncio.to_thread(await_sync_point, srv, pt, timeout=30) for srv, pt in zip((server1, server2), sync_points)))
# Unpause streaming and let decommission finish
logger.info("Unpause streaming")
@@ -268,11 +270,11 @@ async def test_hints_consistency_during_replace(manager: ManagerClient):
# Write 100 rows with CL=ANY. Some of the rows will only be stored as hints because of RF=1
for i in range(100):
await cql.run_async(SimpleStatement(f"INSERT INTO {table} (pk, v) VALUES ({i}, {i + 1})", consistency_level=ConsistencyLevel.ANY))
sync_point = await create_sync_point(manager.api.client, servers[0].ip_addr)
sync_point = create_sync_point(servers[0])
await manager.server_add(replace_cfg=ReplaceConfig(replaced_id = servers[2].server_id, reuse_ip_addr = False, use_host_id = True))
assert await await_sync_point(manager.api.client, servers[0].ip_addr, sync_point, 30)
assert await_sync_point(servers[0], sync_point, 30)
# Verify that all rows were recovered by the hint replay
for i in range(100):
assert list(await cql.run_async(f"SELECT v FROM {table} WHERE pk = {i}")) == [(i + 1,)]
@@ -297,12 +299,16 @@ async def test_draining_hints(manager: ManagerClient):
for i in range(1000):
await cql.run_async(SimpleStatement(f"INSERT INTO ks.t (pk, v) VALUES ({i}, {i + 1})", consistency_level=ConsistencyLevel.ANY))
sync_point = await create_sync_point(manager.api.client, s1.ip_addr)
sync_point = create_sync_point(s1)
await manager.server_start(s2.server_id)
async def wait():
assert await_sync_point(s1, sync_point, 60)
async with asyncio.TaskGroup() as tg:
_ = tg.create_task(manager.decommission_node(s1.server_id, timeout=60))
_ = tg.create_task(await_sync_point(manager.api.client, s1.ip_addr, sync_point, 60))
_ = tg.create_task(wait())
@pytest.mark.asyncio
@pytest.mark.skip_mode(mode='release', reason='error injections are not supported in release mode')
@@ -328,7 +334,7 @@ async def test_canceling_hint_draining(manager: ManagerClient):
for i in range(1000):
await cql.run_async(SimpleStatement(f"INSERT INTO ks.t (pk, v) VALUES ({i}, {i + 1})", consistency_level=ConsistencyLevel.ANY))
sync_point = await create_sync_point(manager.api.client, s1.ip_addr)
sync_point = create_sync_point(s1)
await manager.api.enable_injection(s1.ip_addr, "hinted_handoff_pause_hint_replay", False, {})
await manager.remove_node(s1.server_id, s2.server_id)
@@ -346,7 +352,7 @@ async def test_canceling_hint_draining(manager: ManagerClient):
await s1_log.wait_for(f"Draining starts for {host_id2}", from_mark=s1_mark)
# Make sure draining finishes successfully.
assert await await_sync_point(manager.api.client, s1.ip_addr, sync_point, 60)
assert await_sync_point(s1, sync_point, 60)
await s1_log.wait_for(f"Removed hint directory for {host_id2}")
@pytest.mark.asyncio
@@ -385,7 +391,7 @@ async def test_hint_to_pending(manager: ManagerClient):
await manager.api.enable_injection(servers[0].ip_addr, "hinted_handoff_pause_hint_replay", False)
await manager.server_start(servers[1].server_id)
sync_point = await create_sync_point(manager.api.client, servers[0].ip_addr)
sync_point = create_sync_point(servers[0])
await manager.api.enable_injection(servers[0].ip_addr, "pause_after_streaming_tablet", False)
tablet_migration = asyncio.create_task(manager.api.move_tablet(servers[0].ip_addr, ks, "t", host_ids[1], 0, host_ids[0], 0, 0))
@@ -397,7 +403,7 @@ async def test_hint_to_pending(manager: ManagerClient):
await wait_for(migration_reached_streaming, time.time() + 60)
await manager.api.disable_injection(servers[0].ip_addr, "hinted_handoff_pause_hint_replay")
assert await await_sync_point(manager.api.client, servers[0].ip_addr, sync_point, 30)
assert await_sync_point(servers[0], sync_point, 30)
await manager.api.message_injection(servers[0].ip_addr, "pause_after_streaming_tablet")
done, pending = await asyncio.wait([tablet_migration])

View File

@@ -3,7 +3,7 @@
#
# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
#
import asyncio
import logging
import time
@@ -53,14 +53,7 @@ async def test_no_removed_node_event_on_ip_change(manager: ManagerClient, caplog
logger.info("waiting for cql and hosts")
await wait_for_cql_and_get_hosts(test_cql, servers, time.time() + 30)
# This for loop is done to avoid the race condition when we're checking the logs before a message is arrived.
# Locally issue was not reproducible, but on CI it was.
log_output = caplog.text
for i in range(5):
try:
assert f"'change_type': 'NEW_NODE', 'address': ('{s1_new_ip}'" in log_output
break
except AssertionError:
await asyncio.sleep(i)
log_output = caplog.text
log_output: str = caplog.text
assert f"'change_type': 'NEW_NODE', 'address': ('{s1_new_ip}'" in log_output
assert f"'change_type': 'REMOVED_NODE', 'address': ('{s1_old_ip}'" not in log_output

View File

@@ -4,7 +4,6 @@
# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
#
from test.pylib.manager_client import ManagerClient
from test.pylib.rest_client import read_barrier
from test.cluster.util import new_test_keyspace
from collections import defaultdict
import pytest
@@ -55,9 +54,6 @@ async def test_balance_empty_tablets(manager: ManagerClient):
await manager.api.quiesce_topology(servers[0].ip_addr)
# Ensure all nodes see the same data in system.tablets
await asyncio.gather(*[read_barrier(manager.api, s.ip_addr) for s in servers])
replicas_per_node = defaultdict(int)
tablets_per_shard = {}
for row in await cql.run_async('SELECT * FROM system.tablets'):

View File

@@ -53,9 +53,6 @@ async def test_autoretrain_dict(manager: ManagerClient):
n_blobs = 1024
uncompressed_size = blob_size * n_blobs * rf
# Start with compressor without a dictionary
cfg = { "sstable_compression_user_table_options": "ZstdCompressor" }
logger.info("Bootstrapping cluster")
servers = await manager.servers_add(2, cmdline=[
'--logger-log-level=storage_service=debug',
@@ -64,7 +61,7 @@ async def test_autoretrain_dict(manager: ManagerClient):
'--sstable-compression-dictionaries-retrain-period-in-seconds=1',
'--sstable-compression-dictionaries-autotrainer-tick-period-in-seconds=1',
f'--sstable-compression-dictionaries-min-training-dataset-bytes={int(uncompressed_size/2)}',
], auto_rack_dc="dc1", config=cfg)
], auto_rack_dc="dc1")
logger.info("Creating table")
cql = manager.get_cql()
@@ -79,9 +76,9 @@ async def test_autoretrain_dict(manager: ManagerClient):
await asyncio.gather(*[manager.api.disable_autocompaction(s.ip_addr, ks_name, cf_name) for s in servers])
async def repopulate():
blob = random.randbytes(blob_size)
blob = random.randbytes(blob_size);
insert = cql.prepare("INSERT INTO test.test (pk, c) VALUES (?, ?);")
insert.consistency_level = ConsistencyLevel.ALL
insert.consistency_level = ConsistencyLevel.ALL;
for pks in itertools.batched(range(n_blobs), n=100):
await asyncio.gather(*[
cql.run_async(insert, [k, blob])

View File

@@ -467,9 +467,6 @@ async def test_restart_leaving_replica_during_cleanup(manager: ManagerClient, mi
# Restart the leaving replica (src_server)
await manager.server_restart(src_server.server_id)
cql = await reconnect_driver(manager)
await wait_for_cql_and_get_hosts(cql, servers, time.time() + 60)
await asyncio.gather(*[manager.api.disable_injection(s.ip_addr, injection) for s in servers])
await manager.enable_tablet_balancing()
@@ -490,6 +487,9 @@ async def test_restart_leaving_replica_during_cleanup(manager: ManagerClient, mi
return True
await wait_for(tablets_merged, time.time() + 60)
# Workaround for https://github.com/scylladb/scylladb/issues/21779. We don't want the keyspace drop at the end
# of new_test_keyspace to fail because of concurrent tablet migrations.
await manager.disable_tablet_balancing()
@pytest.mark.asyncio
@pytest.mark.skip_mode(mode='release', reason='error injections are not supported in release mode')

View File

@@ -304,7 +304,7 @@ async def test_remove_node_violating_rf_rack_with_rack_list(manager: ManagerClie
"""
Test removing a node when it would violate RF-rack constraints with explicit rack list.
Creates a cluster with 5 racks (r1, r2, r3, r4, r5) and a keyspace that explicitly
Creates a cluster with 4 racks (r1, r2, r3, r4) and a keyspace that explicitly
specifies RF as a list of racks ['r1', 'r2', 'r4'].
Tests that:
@@ -323,12 +323,11 @@ async def test_remove_node_violating_rf_rack_with_rack_list(manager: ManagerClie
elif op == "decommission":
await manager.decommission_node(server_id, expected_error=expected_error)
servers = await manager.servers_add(5, config=cfg, cmdline=cmdline, property_file=[
servers = await manager.servers_add(4, config=cfg, cmdline=cmdline, property_file=[
{"dc": "dc1", "rack": "r1"},
{"dc": "dc1", "rack": "r2"},
{"dc": "dc1", "rack": "r3"},
{"dc": "dc1", "rack": "r4"},
{"dc": "dc1", "rack": "r5"},
])
cql = manager.get_cql()

View File

@@ -530,6 +530,8 @@ def testDropAndReaddDroppedCollection(cql, test_keyspace):
execute(cql, table, "alter table %s drop v")
execute(cql, table, "alter table %s add v set<text>")
# FIXME: this test is 20 times slower than the rest (run pytest with "--durations=5"
# to see the 5 slowest tests). Is it checking anything worth this slowness??
def testMapWithLargePartition(cql, test_keyspace):
seed = time.time()
print(f"Seed {seed}")
@@ -538,7 +540,7 @@ def testMapWithLargePartition(cql, test_keyspace):
with create_table(cql, test_keyspace, "(userid text PRIMARY KEY, properties map<int, text>) with compression = {}") as table:
numKeys = 200
for i in range(numKeys):
s = 'x'*length
s = ''.join(random.choice(string.ascii_uppercase) for x in range(length))
execute(cql, table,"UPDATE %s SET properties[?] = ? WHERE userid = 'user'", i, s)
flush(cql, table)

View File

@@ -40,20 +40,14 @@ def simple_no_clustering_table(cql, keyspace):
cql.execute(schema)
# Ensure at least 3 live rows for tests that depend on it
live_rows_needed = 3
for pk in range(0, 10):
# For the first 3 rows, always insert; for the rest, use randomness
if pk < live_rows_needed:
cql.execute(f"INSERT INTO {keyspace}.{table} (pk, v) VALUES ({pk}, 0)")
x = random.randrange(0, 4)
if x == 0:
# partition tombstone
cql.execute(f"DELETE FROM {keyspace}.{table} WHERE pk = {pk}")
else:
x = random.randrange(0, 4)
if x == 0:
# partition tombstone
cql.execute(f"DELETE FROM {keyspace}.{table} WHERE pk = {pk}")
else:
# live row
cql.execute(f"INSERT INTO {keyspace}.{table} (pk, v) VALUES ({pk}, 0)")
# live row
cql.execute(f"INSERT INTO {keyspace}.{table} (pk, v) VALUES ({pk}, 0)")
if pk == 5:
nodetool.flush(cql, f"{keyspace}.{table}")

View File

@@ -15,7 +15,6 @@
#include "db/view/view_building_worker.hh"
#include "replica/database_fwd.hh"
#include "test/lib/cql_test_env.hh"
#include "test/lib/test_utils.hh"
#include "cdc/generation_service.hh"
#include "cql3/functions/functions.hh"
#include "cql3/query_processor.hh"
@@ -83,6 +82,7 @@
#include "utils/disk_space_monitor.hh"
#include <sys/time.h>
#include <sys/resource.h>
using namespace std::chrono_literals;
@@ -222,10 +222,26 @@ private:
}
return ::make_shared<service::query_state>(_core_local.local().client_state, empty_service_permit());
}
static void adjust_rlimit() {
// Tests should use 1024 file descriptors, but don't punish them
// with weird behavior if they do.
//
// Since this more of a courtesy, don't make the situation worse if
// getrlimit/setrlimit fail for some reason.
struct rlimit lim;
int r = getrlimit(RLIMIT_NOFILE, &lim);
if (r == -1) {
return;
}
if (lim.rlim_cur < lim.rlim_max) {
lim.rlim_cur = lim.rlim_max;
setrlimit(RLIMIT_NOFILE, &lim);
}
}
public:
single_node_cql_env()
{
tests::adjust_rlimit();
adjust_rlimit();
}
virtual future<::shared_ptr<cql_transport::messages::result_message>> execute_cql(std::string_view text) override {
@@ -861,12 +877,16 @@ private:
std::set<gms::inet_address> seeds;
auto seed_provider = db::config::seed_provider_type();
if (seed_provider.parameters.contains("seeds")) {
for (const auto& seed : utils::split_comma_separated_list(seed_provider.parameters.at("seeds"))) {
seeds.emplace(seed);
size_t begin = 0;
size_t next = 0;
sstring seeds_str = seed_provider.parameters.find("seeds")->second;
while (begin < seeds_str.length() && begin != (next=seeds_str.find(",",begin))) {
seeds.emplace(gms::inet_address(seeds_str.substr(begin,next-begin)));
begin = next+1;
}
}
if (seeds.empty()) {
seeds.emplace("127.0.0.1");
seeds.emplace(gms::inet_address("127.0.0.1"));
}
gms::gossip_config gcfg;

View File

@@ -350,7 +350,6 @@ public:
};
future<> test_env::do_with_async(noncopyable_function<void (test_env&)> func, test_env_config cfg) {
tests::adjust_rlimit();
if (!cfg.storage.is_local_type()) {
auto db_cfg = make_shared<db::config>();
db_cfg->experimental_features({db::experimental_features_t::feature::KEYSPACE_STORAGE_OPTIONS});

View File

@@ -17,7 +17,6 @@
#include "replica/database.hh"
#include "seastarx.hh"
#include <random>
#include <sys/resource.h>
namespace tests {
@@ -141,23 +140,6 @@ sstring make_random_numeric_string(size_t size) {
namespace tests {
void adjust_rlimit() {
// Tests should use 1024 file descriptors, but don't punish them
// with weird behavior if they do.
//
// Since this more of a courtesy, don't make the situation worse if
// getrlimit/setrlimit fail for some reason.
struct rlimit lim;
int r = getrlimit(RLIMIT_NOFILE, &lim);
if (r == -1) {
return;
}
if (lim.rlim_cur < lim.rlim_max) {
lim.rlim_cur = lim.rlim_max;
setrlimit(RLIMIT_NOFILE, &lim);
}
}
future<bool> compare_files(std::string fa, std::string fb) {
auto cont_a = co_await util::read_entire_file_contiguous(fa);
auto cont_b = co_await util::read_entire_file_contiguous(fb);

View File

@@ -114,7 +114,6 @@ inline auto check_run_test_decorator(std::string_view test_var, bool def = false
}
extern boost::test_tools::assertion_result has_scylla_test_env(boost::unit_test::test_unit_id);
void adjust_rlimit();
future<bool> compare_files(std::string fa, std::string fb);
future<> touch_file(std::string name);

View File

@@ -113,7 +113,7 @@ future<> apply_resize_plan(token_metadata& tm, const migration_plan& plan) {
// Reflects the plan in a given token metadata as if the migrations were fully executed.
static
future<> apply_plan(token_metadata& tm, const migration_plan& plan, locator::load_stats& load_stats) {
future<> apply_plan(token_metadata& tm, const migration_plan& plan) {
for (auto&& mig : plan.migrations()) {
co_await tm.tablets().mutate_tablet_map_async(mig.tablet.table, [&mig] (tablet_map& tmap) {
auto tinfo = tmap.get_tablet_info(mig.tablet.tablet);
@@ -121,18 +121,6 @@ future<> apply_plan(token_metadata& tm, const migration_plan& plan, locator::loa
tmap.set_tablet(mig.tablet.tablet, tinfo);
return make_ready_future();
});
// Move tablet size in load_stats to account for the migration
if (mig.src.host != mig.dst.host) {
auto& tmap = tm.tablets().get_tablet_map(mig.tablet.table);
const dht::token_range trange = tmap.get_token_range(mig.tablet.tablet);
lw_shared_ptr<locator::load_stats> new_stats = load_stats.migrate_tablet_size(mig.src.host, mig.dst.host, mig.tablet, trange);
if (new_stats) {
load_stats = std::move(*new_stats);
} else {
throw std::runtime_error(format("Unable to migrate tablet size in load_stats for migration: {}", mig));
}
}
}
co_await apply_resize_plan(tm, plan);
}
@@ -153,7 +141,7 @@ struct rebalance_stats {
};
static
rebalance_stats rebalance_tablets(cql_test_env& e, locator::load_stats& load_stats, std::unordered_set<host_id> skiplist = {}) {
rebalance_stats rebalance_tablets(cql_test_env& e, locator::load_stats_ptr load_stats = {}, std::unordered_set<host_id> skiplist = {}) {
rebalance_stats stats;
abort_source as;
@@ -167,10 +155,9 @@ rebalance_stats rebalance_tablets(cql_test_env& e, locator::load_stats& load_sta
for (size_t i = 0; i < max_iterations; ++i) {
auto prev_lb_stats = *talloc.stats().for_dc(dc);
auto load_stats_p = make_lw_shared<locator::load_stats>(load_stats);
auto start_time = std::chrono::steady_clock::now();
auto plan = talloc.balance_tablets(stm.get(), nullptr, nullptr, load_stats_p, skiplist).get();
auto plan = talloc.balance_tablets(stm.get(), nullptr, nullptr, load_stats, skiplist).get();
auto end_time = std::chrono::steady_clock::now();
auto lb_stats = *talloc.stats().for_dc(dc) - prev_lb_stats;
@@ -204,7 +191,7 @@ rebalance_stats rebalance_tablets(cql_test_env& e, locator::load_stats& load_sta
return stats;
}
stm.mutate_token_metadata([&] (token_metadata& tm) {
return apply_plan(tm, plan, load_stats);
return apply_plan(tm, plan);
}).get();
}
throw std::runtime_error("rebalance_tablets(): convergence not reached within limit");
@@ -220,7 +207,6 @@ struct params {
int shards;
int scale1 = 1;
int scale2 = 1;
double tablet_size_deviation_factor = 0.5;
};
struct table_balance {
@@ -246,7 +232,7 @@ template<>
struct fmt::formatter<table_balance> : fmt::formatter<string_view> {
template <typename FormatContext>
auto format(const table_balance& b, FormatContext& ctx) const {
return fmt::format_to(ctx.out(), "{{shard={} (best={}), node={}}}",
return fmt::format_to(ctx.out(), "{{shard={:.2f} (best={:.2f}), node={:.2f}}}",
b.shard_overcommit, b.best_shard_overcommit, b.node_overcommit);
}
};
@@ -265,53 +251,14 @@ struct fmt::formatter<params> : fmt::formatter<string_view> {
auto format(const params& p, FormatContext& ctx) const {
auto tablets1_per_shard = double(p.tablets1.value_or(0)) * p.rf1 / (p.nodes * p.shards);
auto tablets2_per_shard = double(p.tablets2.value_or(0)) * p.rf2 / (p.nodes * p.shards);
return fmt::format_to(ctx.out(), "{{iterations={}, nodes={}, tablets1={} ({:0.1f}/sh), tablets2={} ({:0.1f}/sh), rf1={}, rf2={}, shards={}, tablet_size_deviation_factor={}}}",
return fmt::format_to(ctx.out(), "{{iterations={}, nodes={}, tablets1={} ({:0.1f}/sh), tablets2={} ({:0.1f}/sh), rf1={}, rf2={}, shards={}}}",
p.iterations, p.nodes,
p.tablets1.value_or(0), tablets1_per_shard,
p.tablets2.value_or(0), tablets2_per_shard,
p.rf1, p.rf2, p.shards, p.tablet_size_deviation_factor);
p.rf1, p.rf2, p.shards);
}
};
class tablet_size_generator {
std::default_random_engine _rnd_engine{std::random_device{}()};
std::normal_distribution<> _dist;
public:
explicit tablet_size_generator(double deviation_factor)
: _dist(default_target_tablet_size, default_target_tablet_size * deviation_factor) {
}
uint64_t generate() {
// We can't have a negative tablet size, which is why we need to minimize it to 0 (with std::max()).
// One consequence of this is that the average generated tablet size will actually
// be larger than default_target_tablet_size.
// This will be especially pronounced as deviation_factor gets larger. For instance:
//
// deviation_factor | avg tablet size
// -----------------+----------------------------------------
// 1 | default_target_tablet_size * 1.08
// 1.5 | default_target_tablet_size * 1.22
// 2 | default_target_tablet_size * 1.39
// 3 | default_target_tablet_size * 1.76
return std::max(0.0, _dist(_rnd_engine));
}
};
void generate_tablet_sizes(double tablet_size_deviation_factor, locator::load_stats& stats, locator::shared_token_metadata& stm) {
tablet_size_generator tsg(tablet_size_deviation_factor);
for (auto&& [table, tmap] : stm.get()->tablets().all_tables_ungrouped()) {
tmap->for_each_tablet([&] (tablet_id tid, const tablet_info& ti) -> future<> {
for (const auto& replica : ti.replicas) {
const uint64_t tablet_size = tsg.generate();
locator::range_based_tablet_id rb_tid {table, tmap->get_token_range(tid)};
stats.tablet_stats[replica.host].tablet_sizes[rb_tid.table][rb_tid.range] = tablet_size;
testlog.trace("Generated tablet size {} for {}:{}", tablet_size, table, tid);
}
return make_ready_future<>();
}).get();
}
}
future<results> test_load_balancing_with_many_tables(params p, bool tablet_aware) {
auto cfg = tablet_cql_test_config();
results global_res;
@@ -325,7 +272,6 @@ future<results> test_load_balancing_with_many_tables(params p, bool tablet_aware
const size_t rf2 = p.rf2;
const shard_id shard_count = p.shards;
const int cycles = p.iterations;
const uint64_t shard_capacity = default_target_tablet_size * 100;
struct host_info {
host_id id;
@@ -348,22 +294,19 @@ future<results> test_load_balancing_with_many_tables(params p, bool tablet_aware
const sstring dc1 = topo.dc();
populate_racks(rf1);
// The rack for which we output stats
sstring test_rack = racks.front().rack;
const size_t rack_count = racks.size();
std::unordered_map<sstring, uint64_t> rack_capacity;
auto add_host = [&] (endpoint_dc_rack dc_rack) {
auto host = topo.add_node(service::node_state::normal, shard_count, dc_rack);
hosts.emplace_back(host, dc_rack);
const uint64_t capacity = shard_capacity * shard_count;
stats.capacity[host] = capacity;
stats.tablet_stats[host].effective_capacity = capacity;
rack_capacity[dc_rack.rack] += capacity;
stats.capacity[host] = default_target_tablet_size * shard_count;
testlog.info("Added new node: {} / {}:{}", host, dc_rack.dc, dc_rack.rack);
};
auto make_stats = [&] {
return make_lw_shared<locator::load_stats>(stats);
};
for (size_t i = 0; i < n_hosts; ++i) {
add_host(racks[i % rack_count]);
}
@@ -372,7 +315,7 @@ future<results> test_load_balancing_with_many_tables(params p, bool tablet_aware
auto bootstrap = [&] (endpoint_dc_rack dc_rack) {
add_host(std::move(dc_rack));
global_res.stats += rebalance_tablets(e, stats);
global_res.stats += rebalance_tablets(e, make_stats());
};
auto decommission = [&] (host_id host) {
@@ -383,15 +326,13 @@ future<results> test_load_balancing_with_many_tables(params p, bool tablet_aware
throw std::runtime_error(format("No such host: {}", host));
}
topo.set_node_state(host, service::node_state::decommissioning);
global_res.stats += rebalance_tablets(e, stats);
global_res.stats += rebalance_tablets(e, make_stats());
if (stm.get()->tablets().has_replica_on(host)) {
throw std::runtime_error(format("Host {} still has replicas!", host));
}
topo.set_node_state(host, service::node_state::left);
testlog.info("Node decommissioned: {}", host);
rack_capacity[it->dc_rack.rack] -= stats.capacity.at(host);
hosts.erase(it);
stats.tablet_stats.erase(host);
};
auto ks1 = add_keyspace(e, {{dc1, rf1}}, p.tablets1.value_or(1));
@@ -401,135 +342,49 @@ future<results> test_load_balancing_with_many_tables(params p, bool tablet_aware
schema_ptr s1 = e.local_db().find_schema(id1);
schema_ptr s2 = e.local_db().find_schema(id2);
generate_tablet_sizes(p.tablet_size_deviation_factor, stats, stm);
// Compute table size per rack, and collect all tablets per rack
std::unordered_map<sstring, std::unordered_map<table_id, uint64_t>> table_sizes_per_rack;
std::unordered_map<sstring, std::unordered_map<table_id, std::vector<uint64_t>>> tablet_sizes_in_rack;
for (auto& [host, tls] : stats.tablet_stats) {
auto host_i = std::ranges::find(hosts, host, &host_info::id);
if (host_i == hosts.end()) {
throw std::runtime_error(format("Host {} not found in hosts", host));
}
auto rack = host_i->dc_rack.rack;
for (auto& [table, ranges] : tls.tablet_sizes) {
for (auto& [trange, tablet_size] : ranges) {
table_sizes_per_rack[rack][table] += tablet_size;
tablet_sizes_in_rack[rack][table].push_back(tablet_size);
}
}
}
// Sort the tablet sizes per rack in descending order
for (auto& [rack, tables] : tablet_sizes_in_rack) {
for (auto& [table, tablets] : tables) {
std::ranges::sort(tablets, std::greater<uint64_t>());
}
}
struct node_used_size {
host_id host;
uint64_t used = 0;
};
// Compute best shard overcommit per table per rack
std::unordered_map<sstring, std::unordered_map<table_id, double>> best_shard_overcommit_per_rack;
auto compute_best_overcommit = [&] () {
auto node_size_compare = [] (const node_used_size& lhs, const node_used_size& rhs) {
return lhs.used > rhs.used;
};
for (auto& all_dc_rack : racks) {
auto rack = all_dc_rack.rack;
// Allocate tablet sizes to nodes
for (auto& [table, tablet_sizes]: tablet_sizes_in_rack.at(rack)) {
load_sketch load(e.shared_token_metadata().local().get(), make_lw_shared<locator::load_stats>(stats));
// Add nodes to load_sketch and to the nodes_used heap
std::vector<node_used_size> nodes_used;
for (const auto& [host_id, host_dc_rack] : hosts) {
if (rack == host_dc_rack.rack) {
load.ensure_node(host_id);
nodes_used.push_back({host_id, 0});
}
}
// Allocate tablets to nodes/shards
for (uint64_t tablet_size : tablet_sizes) {
std::pop_heap(nodes_used.begin(), nodes_used.end(), node_size_compare);
host_id add_to_host = nodes_used.back().host;
nodes_used.back().used += tablet_size;
std::push_heap(nodes_used.begin(), nodes_used.end(), node_size_compare);
// Add to the least loaded shard on the least loaded node
load.next_shard(add_to_host, 1, tablet_size);
}
// Get the best overcommit from all the nodes
min_max_tracker<locator::disk_usage::load_type> load_minmax;
for (const auto& n : nodes_used) {
load_minmax.update(load.get_shard_minmax(n.host));
}
const uint64_t table_size = table_sizes_per_rack.at(rack).at(table);
const double ideal_load = double(table_size) / rack_capacity.at(rack);
const double best_overcommit = load_minmax.max() / ideal_load;
best_shard_overcommit_per_rack[rack][table] = best_overcommit;
}
}
};
auto check_balance = [&] () -> cluster_balance {
cluster_balance res;
testlog.debug("tablet metadata: {}", stm.get()->tablets());
compute_best_overcommit();
auto load_stats_p = make_lw_shared<locator::load_stats>(stats);
int table_index = 0;
for (auto s : {s1, s2}) {
auto table = s->id();
load_sketch load(stm.get(), load_stats_p);
load.populate(std::nullopt, table).get();
load_sketch load(stm.get());
load.populate(std::nullopt, s->id()).get();
min_max_tracker<double> shard_overcommit_minmax;
min_max_tracker<double> node_overcommit_minmax;
auto rack = test_rack;
auto table_size = table_sizes_per_rack.at(rack).at(table);
auto ideal_load = double(table_size) / rack_capacity.at(rack);
min_max_tracker<double> shard_load_minmax;
min_max_tracker<double> node_load_minmax;
for (auto [h, host_dc_rack] : hosts) {
if (host_dc_rack.rack != rack) {
continue;
}
min_max_tracker<uint64_t> shard_load_minmax;
min_max_tracker<uint64_t> node_load_minmax;
uint64_t sum_node_load = 0;
uint64_t shard_count = 0;
for (auto [h, _] : hosts) {
auto minmax = load.get_shard_minmax(h);
auto node_load = load.get_load(h);
auto overcommit = double(minmax.max()) / ideal_load;
testlog.info("Load on host {} for table {}: total={}, min={}, max={}, spread={}, ideal={}, overcommit={}",
h, s->cf_name(), node_load, minmax.min(), minmax.max(), minmax.max() - minmax.min(), ideal_load, overcommit);
node_load_minmax.update(node_load);
auto avg_shard_load = load.get_real_avg_tablet_count(h);
auto overcommit = double(minmax.max()) / avg_shard_load;
shard_load_minmax.update(minmax.max());
shard_count += load.get_shard_count(h);
testlog.info("Load on host {} for table {}: total={}, min={}, max={}, spread={}, avg={:.2f}, overcommit={:.2f}",
h, s->cf_name(), node_load, minmax.min(), minmax.max(), minmax.max() - minmax.min(), avg_shard_load, overcommit);
node_load_minmax.update(node_load);
sum_node_load += node_load;
}
auto shard_overcommit = shard_load_minmax.max() / ideal_load;
auto best_shard_overcommit = best_shard_overcommit_per_rack.at(rack).at(table);
testlog.info("Shard overcommit: {} best: {}", shard_overcommit, best_shard_overcommit);
auto avg_shard_load = double(sum_node_load) / shard_count;
auto shard_overcommit = shard_load_minmax.max() / avg_shard_load;
// Overcommit given the best distribution of tablets given current number of tablets.
auto best_shard_overcommit = div_ceil(sum_node_load, shard_count) / avg_shard_load;
testlog.info("Shard overcommit: {:.2f}, best={:.2f}", shard_overcommit, best_shard_overcommit);
auto node_imbalance = node_load_minmax.max() - node_load_minmax.min();
auto node_overcommit = node_load_minmax.max() / ideal_load;
testlog.info("Node imbalance in min={}, max={}, spread={}, ideal={}, overcommit={}",
node_load_minmax.min(), node_load_minmax.max(), node_imbalance, ideal_load, node_overcommit);
shard_overcommit_minmax.update(shard_overcommit);
node_overcommit_minmax.update(node_overcommit);
auto avg_node_load = double(sum_node_load) / hosts.size();
auto node_overcommit = node_load_minmax.max() / avg_node_load;
testlog.info("Node imbalance: min={}, max={}, spread={}, avg={:.2f}, overcommit={:.2f}",
node_load_minmax.min(), node_load_minmax.max(), node_imbalance, avg_node_load, node_overcommit);
res.tables[table_index++] = {
.shard_overcommit = shard_overcommit_minmax.max(),
.shard_overcommit = shard_overcommit,
.best_shard_overcommit = best_shard_overcommit,
.node_overcommit = node_overcommit_minmax.max(),
.node_overcommit = node_overcommit
};
}
@@ -549,7 +404,7 @@ future<results> test_load_balancing_with_many_tables(params p, bool tablet_aware
check_balance();
rebalance_tablets(e, stats);
rebalance_tablets(e, make_stats());
global_res.init = global_res.worst = check_balance();
@@ -573,7 +428,6 @@ void test_parallel_scaleout(const bpo::variables_map& opts) {
const int nr_racks = opts["racks"].as<int>();
const int initial_nodes = nr_racks * opts["nodes-per-rack"].as<int>();
const int extra_nodes = nr_racks * opts["extra-nodes-per-rack"].as<int>();
const double tablet_size_deviation_factor = opts["tablet-size-deviation-factor"].as<double>();
auto cfg = tablet_cql_test_config();
cfg.db_config->rf_rack_valid_keyspaces(true);
@@ -582,6 +436,10 @@ void test_parallel_scaleout(const bpo::variables_map& opts) {
topology_builder topo(e);
locator::load_stats stats;
auto make_stats = [&] {
return make_lw_shared<locator::load_stats>(stats);
};
std::vector<endpoint_dc_rack> racks;
racks.push_back(topo.rack());
for (int i = 1; i < nr_racks; ++i) {
@@ -590,9 +448,7 @@ void test_parallel_scaleout(const bpo::variables_map& opts) {
auto add_host = [&] (endpoint_dc_rack rack) {
auto host = topo.add_node(service::node_state::normal, shard_count, rack);
const uint64_t capacity = default_target_tablet_size * shard_count * 100;
stats.capacity[host] = capacity;
stats.tablet_stats[host].effective_capacity = capacity;
stats.capacity[host] = default_target_tablet_size * shard_count;
testlog.info("Added new node: {}", host);
};
@@ -610,14 +466,12 @@ void test_parallel_scaleout(const bpo::variables_map& opts) {
return add_table(e, ks1).discard_result();
}).get();
generate_tablet_sizes(tablet_size_deviation_factor, stats, e.shared_token_metadata().local());
testlog.info("Initial rebalancing");
rebalance_tablets(e, stats);
rebalance_tablets(e, make_stats());
testlog.info("Scaleout");
add_hosts(extra_nodes);
global_res.stats += rebalance_tablets(e, stats);
global_res.stats += rebalance_tablets(e, make_stats());
}, cfg).get();
}
@@ -652,7 +506,7 @@ future<> run_simulation(const params& p, const sstring& name = "") {
}
auto overcommit = res.worst.tables[i].shard_overcommit;
if (overcommit > 1.2) {
testlog.warn("[run {}] table{} shard overcommit {:.4f} > 1.2!", name, i + 1, overcommit);
testlog.warn("[run {}] table{} shard overcommit {:.2f} > 1.2!", name, i + 1, overcommit);
}
}
}
@@ -670,8 +524,6 @@ future<> run_simulations(const boost::program_options::variables_map& app_cfg) {
auto scale1 = 1 << tests::random::get_int(0, 5);
auto scale2 = 1 << tests::random::get_int(0, 5);
auto nodes = tests::random::get_int(rf1 + rf2, 2 * MAX_RF);
// results in a deviation factor of 0.0 - 2.0
auto tablet_size_deviation_factor = tests::random::get_int(0, 200) / 100.0;
params p {
.iterations = app_cfg["iterations"].as<int>(),
@@ -683,7 +535,6 @@ future<> run_simulations(const boost::program_options::variables_map& app_cfg) {
.shards = shards,
.scale1 = scale1,
.scale2 = scale2,
.tablet_size_deviation_factor = tablet_size_deviation_factor
};
auto name = format("#{}", i);
@@ -705,7 +556,6 @@ void run_add_dec(const bpo::variables_map& opts) {
.rf1 = opts["rf1"].as<int>(),
.rf2 = opts["rf2"].as<int>(),
.shards = opts["shards"].as<int>(),
.tablet_size_deviation_factor = opts["tablet-size-deviation-factor"].as<double>(),
};
run_simulation(p).get();
}
@@ -729,8 +579,7 @@ const std::map<operation, operation_func> operations_with_func{
typed_option<int>("rf1", 1, "Replication factor for the first table."),
typed_option<int>("rf2", 1, "Replication factor for the second table."),
typed_option<int>("nodes", 3, "Number of nodes in the cluster."),
typed_option<int>("shards", 30, "Number of shards per node."),
typed_option<double>("tablet-size-deviation-factor", 0.5, "Deviation factor for the tablet size random generator.")
typed_option<int>("shards", 30, "Number of shards per node.")
}
}, &run_add_dec},
@@ -743,8 +592,7 @@ const std::map<operation, operation_func> operations_with_func{
typed_option<int>("nodes-per-rack", 5, "Number of initial nodes per rack."),
typed_option<int>("extra-nodes-per-rack", 3, "Number of nodes to add per rack."),
typed_option<int>("racks", 2, "Number of racks."),
typed_option<int>("shards", 88, "Number of shards per node."),
typed_option<double>("tablet-size-deviation-factor", 0.5, "Deviation factor for the tablet size random generator.")
typed_option<int>("shards", 88, "Number of shards per node.")
}
}, &test_parallel_scaleout},
}

View File

@@ -12,7 +12,6 @@ import shlex
import subprocess
from abc import ABC, abstractmethod
from functools import cached_property
from pathlib import Path
from types import SimpleNamespace
from typing import TYPE_CHECKING
@@ -103,7 +102,7 @@ class CppFile(pytest.File, ABC):
...
@abstractmethod
def run_test_case(self, test_case: CppTestCase) -> tuple[None | list[CppTestFailure], Path]:
def run_test_case(self, test_case: CppTestCase) -> tuple[None | list[CppTestFailure], str]:
...
@cached_property
@@ -212,18 +211,8 @@ class CppTestCase(pytest.Item):
def runtest(self) -> None:
failures, output = self.parent.run_test_case(test_case=self)
# Write output to stdout so pytest captures it for both terminal and JUnit report.
# Only show the last 300 lines to avoid excessive output.
lines = get_lines_from_end(output)
if lines:
print("\n" + "=" * 70)
print("C++ Test Output (last 300 lines):")
print("=" * 70)
print('\n'.join(lines))
print("=" * 70 + "\n")
if not self.config.getoption("--save-log-on-success"):
output.unlink(missing_ok=True)
# Report the c++ output in its own sections.
self.add_report_section(when="call", key="c++", content=output)
if failures:
raise CppTestFailureList(failures)
@@ -288,31 +277,3 @@ class CppFailureRepr:
if index != len(self.failures) - 1:
tw.line(self.failure_sep, cyan=True)
def get_lines_from_end(file_path: pathlib.Path, lines_count: int = 300) -> list[str]:
"""
Seeks to the end of the file and reads backwards to find the last N lines
without iterating over the whole file.
"""
chunk_size = 8192 # 8KB chunks
buffer = ""
with file_path.open("rb") as f:
f.seek(0, os.SEEK_END)
file_size = f.tell()
pointer = file_size
while pointer > 0:
# Read one chunk backwards
pointer -= min(pointer, chunk_size)
f.seek(pointer)
chunk = f.read(min(file_size - pointer, chunk_size)).decode('utf-8', errors='ignore')
buffer = chunk + buffer
# Stop once we have enough lines
if len(buffer.splitlines()) > lines_count:
break
# Return only the requested number of lines
return buffer.splitlines()[-lines_count:]

View File

@@ -14,7 +14,6 @@ import pathlib
import json
from functools import cache, cached_property
from itertools import chain
from pathlib import Path
from textwrap import dedent
from typing import TYPE_CHECKING
from xml.etree import ElementTree
@@ -61,7 +60,7 @@ class BoostTestFile(CppFile):
return [self.test_name]
return get_boost_test_list_json_content(executable=self.exe_path,combined=self.combined).get(self.test_name, [])
def run_test_case(self, test_case: CppTestCase) -> tuple[list[CppTestFailure], Path] | tuple[None, Path]:
def run_test_case(self, test_case: CppTestCase) -> tuple[None | list[CppTestFailure], str]:
run_test = f"{self.test_name}/{test_case.test_case_name}" if self.combined else test_case.test_case_name
log_sink = tempfile.NamedTemporaryFile(mode="w+t")
@@ -87,8 +86,6 @@ class BoostTestFile(CppFile):
log_xml = pathlib.Path(log_sink.name).read_text(encoding="utf-8")
except IOError:
log_xml = ""
finally:
log_sink.close()
results = parse_boost_test_log_sink(log_xml=log_xml)
if return_code := process.returncode:
@@ -103,9 +100,13 @@ class BoostTestFile(CppFile):
command to repeat: {subprocess.list2cmdline(process.args)}
error: {results[0].lines if results else 'unknown'}
"""),
)], stdout_file_path
)], ""
return None, stdout_file_path
if not self.config.getoption("--save-log-on-success"):
log_sink.close()
stdout_file_path.unlink(missing_ok=True)
return None, ""
pytest_collect_file = BoostTestFile.pytest_collect_file

View File

@@ -8,7 +8,6 @@ from __future__ import annotations
import os
import subprocess
from pathlib import Path
from textwrap import dedent
from typing import TYPE_CHECKING
@@ -24,7 +23,7 @@ class UnitTestFile(CppFile):
def list_test_cases(self) -> list[str]:
return [self.test_name]
def run_test_case(self, test_case: CppTestCase) -> tuple[list[CppTestFailure], Path] | tuple[None, Path]:
def run_test_case(self, test_case: CppTestCase) -> tuple[None | list[CppTestFailure], str]:
stdout_file_path = test_case.get_artifact_path(extra="_stdout", suffix=".log").absolute()
process = test_case.run_exe(test_args=self.test_args, output_file=stdout_file_path)
@@ -39,9 +38,12 @@ class UnitTestFile(CppFile):
output file: {stdout_file_path}
command to repeat: {subprocess.list2cmdline(process.args)}
"""),
)], stdout_file_path
)], ""
return None, stdout_file_path
if not self.config.getoption("--save-log-on-success"):
stdout_file_path.unlink(missing_ok=True)
return None, ""
pytest_collect_file = UnitTestFile.pytest_collect_file

View File

@@ -172,11 +172,7 @@ class MinioServer:
preexec_fn=os.setsid,
stderr=self.log_file,
stdout=self.log_file,
env={
**os.environ,
'MINIO_BROWSER': 'off',
'MINIO_FS_OSYNC': 'off',
},
env={**os.environ, 'MINIO_BROWSER': 'off'},
)
timeout = time.time() + 30
while time.time() < timeout:

View File

@@ -7,8 +7,6 @@
"""
from __future__ import annotations # Type hints as strings
import asyncio
import json
import logging
import os.path
from urllib.parse import quote
@@ -18,7 +16,7 @@ from contextlib import asynccontextmanager
from typing import Any, Optional, AsyncIterator
import pytest
from aiohttp import request, BaseConnector, UnixConnector, ClientTimeout, ClientSession
from aiohttp import request, BaseConnector, UnixConnector, ClientTimeout
from cassandra.pool import Host # type: ignore # pylint: disable=no-name-in-module
from test.pylib.internal_types import IPAddress, HostID
@@ -713,143 +711,3 @@ def get_host_api_address(host: Host) -> IPAddress:
In particular, in case the RPC address has been modified.
"""
return host.listen_address if host.listen_address else host.address
class InjectionEventStream:
"""Client for Server-Sent Events stream of error injection events.
This allows tests to wait for injection points to be hit without log parsing.
Each event contains: injection name, type (sleep/handler/exception/lambda), and shard ID.
"""
def __init__(self, node_ip: IPAddress, port: int = 10000):
self.node_ip = node_ip
self.port = port
self.session: Optional[ClientSession] = None
self._events: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
self._reader_task: Optional[asyncio.Task] = None
self._connected = asyncio.Event()
async def __aenter__(self):
"""Connect to SSE stream"""
await self.connect()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Disconnect from SSE stream"""
await self.disconnect()
return False
async def connect(self):
"""Establish SSE connection and start reading events"""
if self.session is not None:
return # Already connected
self.session = ClientSession()
url = f"http://{self.node_ip}:{self.port}/v2/error_injection/events"
# Start background task to read SSE events
self._reader_task = asyncio.create_task(self._read_events(url))
# Wait for connection to be established
await asyncio.wait_for(self._connected.wait(), timeout=10.0)
logger.info(f"Connected to injection event stream at {url}")
async def disconnect(self):
"""Close SSE connection"""
if self._reader_task:
self._reader_task.cancel()
try:
await self._reader_task
except asyncio.CancelledError:
pass
self._reader_task = None
if self.session:
await self.session.close()
self.session = None
async def _read_events(self, url: str):
"""Background task to read SSE events"""
try:
async with self.session.get(url, timeout=ClientTimeout(total=None)) as resp:
if resp.status != 200:
logger.error(f"Failed to connect to SSE stream: {resp.status}")
return
# Signal connection established
self._connected.set()
# Read SSE events line by line
async for line in resp.content:
line = line.decode('utf-8').strip()
# SSE format: "data: <json>"
if line.startswith('data: '):
json_str = line[6:] # Remove "data: " prefix
try:
event = json.loads(json_str)
await self._events.put(event)
logger.debug(f"Received injection event: {event}")
except json.JSONDecodeError as e:
logger.warning(f"Failed to parse SSE event: {json_str}, error: {e}")
elif line.startswith(':'):
# SSE comment (connection keepalive)
pass
except asyncio.CancelledError:
logger.debug("SSE reader task cancelled")
raise
except Exception as e:
logger.error(f"Error reading SSE stream: {e}", exc_info=True)
async def wait_for_injection(self, injection_name: str, timeout: float = 30.0) -> dict[str, Any]:
"""Wait for a specific injection to be triggered.
Args:
injection_name: Name of the injection to wait for
timeout: Maximum time to wait in seconds
Returns:
Event dictionary with keys: injection, type, shard
Raises:
asyncio.TimeoutError: If injection not triggered within timeout
"""
deadline = asyncio.get_event_loop().time() + timeout
while True:
remaining = deadline - asyncio.get_event_loop().time()
if remaining <= 0:
raise asyncio.TimeoutError(
f"Injection '{injection_name}' not triggered within {timeout}s"
)
try:
event = await asyncio.wait_for(self._events.get(), timeout=remaining)
if event.get('injection') == injection_name:
return event
# Not the injection we're waiting for, continue
except asyncio.TimeoutError:
raise asyncio.TimeoutError(
f"Injection '{injection_name}' not triggered within {timeout}s"
)
@asynccontextmanager
async def injection_event_stream(node_ip: IPAddress, port: int = 10000) -> AsyncIterator[InjectionEventStream]:
"""Context manager for error injection event stream.
Usage:
async with injection_event_stream(node_ip) as stream:
await api.enable_injection(node_ip, "my_injection", one_shot=True)
# Start operation that will trigger injection
event = await stream.wait_for_injection("my_injection", timeout=30)
logger.info(f"Injection triggered on shard {event['shard']}")
"""
stream = InjectionEventStream(node_ip, port)
try:
await stream.connect()
yield stream
finally:
await stream.disconnect()

View File

@@ -196,7 +196,7 @@ def pytest_sessionstart(session: pytest.Session) -> None:
)
@pytest.hookimpl(tryfirst=True)
@pytest.hookimpl(trylast=True)
def pytest_runtest_logreport(report):
"""Add custom XML attributes to JUnit testcase elements.
@@ -208,7 +208,7 @@ def pytest_runtest_logreport(report):
Attributes added:
- function_path: The function path of the test case (excluding parameters).
Uses tryfirst=True to run before LogXML's hook has created the node_reporter to avoid double recording.
Uses trylast=True to run after LogXML's hook has created the node_reporter.
"""
# Get the XML reporter
config = _pytest_config

View File

@@ -216,13 +216,9 @@ async def with_file_lock(lock_path: pathlib.Path) -> AsyncIterator[None]:
async def get_scylla_2025_1_executable(build_mode: str) -> str:
async def run_process(cmd, **kwargs):
proc = await asyncio.create_subprocess_exec(
*cmd, stderr=asyncio.subprocess.PIPE, **kwargs)
_, stderr = await proc.communicate()
if proc.returncode != 0:
raise RuntimeError(
f"Command {cmd} failed with exit code {proc.returncode}: {stderr.decode(errors='replace').strip()}"
)
proc = await asyncio.create_subprocess_exec(*cmd, **kwargs)
await proc.communicate()
assert proc.returncode == 0
is_debug = build_mode == 'debug' or build_mode == 'sanitize'
package = "scylla-debug" if is_debug else "scylla"
@@ -249,7 +245,7 @@ async def get_scylla_2025_1_executable(build_mode: str) -> str:
if not unpacked_marker.exists():
if not downloaded_marker.exists():
archive_path.unlink(missing_ok=True)
await run_process(["curl", "--retry", "40", "--retry-max-time", "60", "--fail", "--silent", "--show-error", "--retry-all-errors", "--output", archive_path, url])
await run_process(["curl", "--retry", "10", "--fail", "--silent", "--show-error", "--output", archive_path, url])
downloaded_marker.touch()
shutil.rmtree(unpack_dir, ignore_errors=True)
unpack_dir.mkdir(exist_ok=True, parents=True)

View File

@@ -260,7 +260,7 @@ class PythonTest(Test):
self.is_before_test_ok = True
cluster.take_log_savepoint()
yield cluster
yield
if self.shortname in self.suite.dirties_cluster:
cluster.is_dirty = True

View File

@@ -2,9 +2,6 @@
asyncio_mode = auto
asyncio_default_fixture_loop_scope = session
junit_logging = all
junit_log_passing_tests = False
log_format = %(asctime)s.%(msecs)03d %(levelname)s> %(message)s
log_date_format = %H:%M:%S

View File

@@ -3,73 +3,93 @@
# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
"""Conftest for Scylla GDB tests"""
import logging
import os
import subprocess
import pexpect
import pytest
import re
from test.pylib.runner import testpy_test_fixture_scope
from test.pylib.suite.python import PythonTest
@pytest.fixture(scope=testpy_test_fixture_scope)
async def scylla_server(testpy_test: PythonTest | None):
"""Return a running Scylla server instance from the active test cluster."""
async with testpy_test.run_ctx(options=testpy_test.suite.options) as cluster:
yield next(iter(cluster.running.values()))
from test.pylib.util import LogPrefixAdapter
@pytest.fixture(scope="module")
def gdb_cmd(scylla_server, request):
"""
Returns a command-line (argv list) that attaches to the `scylla_server` PID, loads `scylla-gdb.py`
and `gdb_utils.py`. This is meant to be executed by `execute_gdb_command()` in `--batch` mode.
async def scylla_server(testpy_test: PythonTest | None):
"""Return a running Scylla server instance from the active test cluster."""
logger_prefix = testpy_test.mode + "/" + testpy_test.uname
logger = LogPrefixAdapter(
logging.getLogger(logger_prefix), {"prefix": logger_prefix}
)
scylla_cluster = await testpy_test.suite.clusters.get(logger)
scylla_server = next(iter(scylla_cluster.running.values()))
yield scylla_server
await testpy_test.suite.clusters.put(scylla_cluster, is_dirty=True)
@pytest.fixture(scope="module")
def gdb_process(scylla_server, request):
"""Spawn an interactive GDB attached to the Scylla process.
Loads `scylla-gdb.py` and test helpers (`gdb_utils.py`) so tests can run GDB/Python helpers
against the live Scylla process.
"""
scylla_gdb_py = os.path.join(request.fspath.dirname, "..", "..", "scylla-gdb.py")
script_py = os.path.join(request.fspath.dirname, "gdb_utils.py")
cmd = [
"gdb",
"-q",
"--batch",
"--nx",
"-se",
str(scylla_server.exe),
"-p",
str(scylla_server.cmd.pid),
"-ex",
"set python print-stack full",
"-x",
scylla_gdb_py,
"-x",
script_py,
]
return cmd
cmd = (
f"gdb -q "
"--nx "
"-iex 'set confirm off' "
"-iex 'set pagination off' "
f"-se {scylla_server.exe} "
f"-p {scylla_server.cmd.pid} "
f"-ex set python print-stack full "
f"-x {scylla_gdb_py} "
f"-x {script_py}"
)
gdb_process = pexpect.spawn(cmd, maxread=10, searchwindowsize=10)
gdb_process.expect_exact("(gdb)")
yield gdb_process
gdb_process.terminate()
def execute_gdb_command(gdb_cmd, scylla_command: str = None, full_command: str = None):
"""Execute a single GDB command attached to the running Scylla process.
def execute_gdb_command(
gdb_process, scylla_command: str = None, full_command: str = None
):
"""
Execute a command in an interactive GDB session and return its output.
Builds on `gdb_cmd` and runs GDB via `subprocess.run()` in `--batch` mode.
`scylla_command` is executed as `scylla <cmd>` through GDB's Python interface.
The command can be provided either as a Scylla GDB command (which will be
wrapped and executed via GDB's Python interface) or as a full raw GDB
command string.
The function waits for the GDB prompt to reappear, enforces a timeout,
and fails the test if the command does not complete or if GDB reports an
error.
Args:
gdb_cmd: Base GDB argv list returned by the `gdb_cmd` fixture.
scylla_command: Scylla GDB command name/args (from scylla-gdb.py). Mutually exclusive with `full_command`.
full_command: Raw GDB command string to execute. Mutually exclusive with `scylla_command`.
Returns:
Command stdout as a decoded string.
gdb_process (pexpect.pty_spawn.spawn): An active GDB process spawned via pexpect
scylla_command (str, optional): A GDB Scylla command (from scylla-gdb.py) to execute.
full_command (str, optional): A raw GDB command string to execute.
"""
command = f"python gdb.execute('scylla {scylla_command}')"
if full_command:
command = [*gdb_cmd, "-ex", full_command]
else:
command = [
*gdb_cmd,
"-ex",
f"python gdb.execute('scylla {scylla_command}')",
]
command = full_command
result = subprocess.run(
command, capture_output=True, text=True, encoding="utf-8", errors="replace"
)
gdb_process.sendline(command)
try:
gdb_process.expect_exact("(gdb)", timeout=180)
except pexpect.exceptions.TIMEOUT:
gdb_process.sendcontrol("c")
gdb_process.expect_exact("(gdb)", timeout=1)
pytest.fail("GDB command did not complete within the timeout period")
result = gdb_process.before.decode("utf-8")
# The task_histogram command may include "error::Error" in its output, so
# allow it.
assert not re.search(r'(?<!error::)Error', result)
return result

View File

@@ -26,7 +26,6 @@ pytestmark = [
@pytest.mark.parametrize(
"command",
[
"timers",
"features",
"compaction-tasks",
"databases",
@@ -59,20 +58,19 @@ pytestmark = [
"task_histogram -a",
"tasks",
"threads",
"timers",
"get-config-value compaction_static_shares",
"read-stats",
"prepared-statements",
],
)
def test_scylla_commands(gdb_cmd, command):
result = execute_gdb_command(gdb_cmd, command)
assert result.returncode == 0, (
f"GDB command {command} failed. stdout: {result.stdout} stderr: {result.stderr}"
)
def test_scylla_commands(gdb_process, command):
execute_gdb_command(gdb_process, command)
def test_nonexistent_scylla_command(gdb_cmd):
def test_nonexistent_scylla_command(gdb_process):
"""Verifies that running unknown command will produce correct error message"""
result = execute_gdb_command(gdb_cmd, "nonexistent_command")
assert result.returncode == 1
assert "Undefined scylla command: \"nonexistent_command\"" in result.stderr
with pytest.raises(
AssertionError, match=r'Undefined scylla command: "nonexistent_command"'
):
execute_gdb_command(gdb_process, "nonexistent_command")

View File

@@ -25,13 +25,14 @@ pytestmark = [
@pytest.fixture(scope="module")
def schema(gdb_cmd):
def schema(gdb_process):
"""
Returns pointer to schema of the first table it finds
Even without any user tables, we will always have system tables.
"""
result = execute_gdb_command(gdb_cmd, full_command="python get_schema()").stdout
result = execute_gdb_command(gdb_process, full_command="python get_schema()")
match = re.search(r"schema=\s*(0x[0-9a-fA-F]+)", result)
assert match, f"Failed to find schema pointer in response: {result}"
schema_pointer = match.group(1) if match else None
return schema_pointer
@@ -45,22 +46,12 @@ def schema(gdb_cmd):
"schema (const schema *)", # `schema` requires type-casted pointer
],
)
def test_schema(gdb_cmd, command, schema):
assert schema, "Failed to find schema of any table"
result = execute_gdb_command(gdb_cmd, f"{command} {schema}")
assert result.returncode == 0, (
f"GDB command {command} failed. stdout: {result.stdout} stderr: {result.stderr}"
)
def test_schema(gdb_process, command, schema):
execute_gdb_command(gdb_process, f"{command} {schema}")
def test_generate_object_graph(gdb_cmd, schema, request):
assert schema, "Failed to find schema of any table"
def test_generate_object_graph(gdb_process, schema, request):
tmpdir = request.config.getoption("--tmpdir")
result = execute_gdb_command(
gdb_cmd, f"generate-object-graph -o {tmpdir}/og.dot -d 2 -t 10 {schema}"
)
assert result.returncode == 0, (
f"GDB command `generate-object-graph` failed. stdout: {result.stdout} stderr: {result.stderr}"
execute_gdb_command(
gdb_process, f"generate-object-graph -o {tmpdir}/og.dot -d 2 -t 10 {schema}"
)

View File

@@ -25,10 +25,11 @@ pytestmark = [
@pytest.fixture(scope="module")
def sstable(gdb_cmd):
def sstable(gdb_process):
"""Finds sstable"""
result = execute_gdb_command(gdb_cmd, full_command="python get_sstables()").stdout
result = execute_gdb_command(gdb_process, full_command="python get_sstables()")
match = re.search(r"(\(sstables::sstable \*\) 0x)([0-9a-f]+)", result)
assert match is not None, "No sstable was present in result.stdout"
sstable_pointer = match.group(0).strip() if match else None
return sstable_pointer
@@ -41,10 +42,5 @@ def sstable(gdb_cmd):
"sstable-index-cache",
],
)
def test_sstable(gdb_cmd, command, sstable):
assert sstable, "No sstable was found"
result = execute_gdb_command(gdb_cmd, f"{command} {sstable}")
assert result.returncode == 0, (
f"GDB command {command} failed. stdout: {result.stdout} stderr: {result.stderr}"
)
def test_sstable(gdb_process, command, sstable):
execute_gdb_command(gdb_process, f"{command} {sstable}")

View File

@@ -26,7 +26,7 @@ pytestmark = [
@pytest.fixture(scope="module")
def task(gdb_cmd):
def task(gdb_process):
"""
Finds a Scylla fiber task using a `find_vptrs()` loop.
@@ -35,14 +35,19 @@ def task(gdb_cmd):
skeleton created by `http_server::do_accept_one` (often the earliest
“Scylla fiber” to appear).
"""
result = execute_gdb_command(gdb_cmd, full_command="python get_task()").stdout
result = execute_gdb_command(gdb_process, full_command="python get_task()")
match = re.search(r"task=(\d+)", result)
assert match is not None, f"No task was present in {result.stdout}"
task = match.group(1) if match else None
return task
def test_fiber(gdb_process, task):
execute_gdb_command(gdb_process, f"fiber {task}")
@pytest.fixture(scope="module")
def coroutine_task(gdb_cmd, scylla_server):
def coroutine_task(gdb_process, scylla_server):
"""
Finds a coroutine task, similar to the `task` fixture.
@@ -54,11 +59,11 @@ def coroutine_task(gdb_cmd, scylla_server):
diagnostic information before the test is marked as failed.
Coredump is saved to `testlog/release/{scylla}`.
"""
result = execute_gdb_command(gdb_cmd, full_command="python get_coroutine()").stdout
result = execute_gdb_command(gdb_process, full_command="python get_coroutine()")
match = re.search(r"coroutine_config=\s*(.*)", result)
if not match:
result = execute_gdb_command(
gdb_cmd,
gdb_process,
full_command=f"python coroutine_debug_config('{scylla_server.workdir}')",
)
pytest.fail(
@@ -69,26 +74,12 @@ def coroutine_task(gdb_cmd, scylla_server):
return match.group(1).strip()
def test_coroutine_frame(gdb_cmd, coroutine_task):
def test_coroutine_frame(gdb_process, coroutine_task):
"""
Offsets the pointer by two words to shift from the outer coroutine frame
to the inner `seastar::task`, as required by `$coro_frame`, which expects
a `seastar::task*`.
"""
assert coroutine_task, "No coroutine task was found"
result = execute_gdb_command(
gdb_cmd, full_command=f"p *$coro_frame({coroutine_task} + 16)"
)
assert result.returncode == 0, (
f"GDB command `coro_frame` failed. stdout: {result.stdout} stderr: {result.stderr}"
)
def test_fiber(gdb_cmd, task):
assert task, f"No task was found using `find_vptrs()`"
result = execute_gdb_command(gdb_cmd, f"fiber {task}")
assert result.returncode == 0, (
f"GDB command `fiber` failed. stdout: {result.stdout} stderr: {result.stderr}"
execute_gdb_command(
gdb_process, full_command=f"p *$coro_frame({coroutine_task} + 16)"
)

View File

@@ -43,15 +43,10 @@ class random_content_file:
os.unlink(self.filename)
CRITICAL_DISK_UTILIZATION_LEVEL = 0.5
# Target disk fill ratio used in tests to push the node above the critical
# utilization level.
DISK_FILL_TARGET_RATIO = 1.1 * CRITICAL_DISK_UTILIZATION_LEVEL
# Since we create 20M volumes, we need to reduce the commitlog segment size
# otherwise we hit out of space.
global_cmdline = ["--disk-space-monitor-normal-polling-interval-in-seconds", "1",
"--critical-disk-utilization-level", f"{CRITICAL_DISK_UTILIZATION_LEVEL}",
"--critical-disk-utilization-level", "0.8",
"--commitlog-segment-size-in-mb", "2",
"--schema-commitlog-segment-size-in-mb", "4",
"--tablet-load-stats-refresh-interval-in-seconds", "1",
@@ -85,7 +80,7 @@ async def test_user_writes_rejection(manager: ManagerClient, volumes_factory: Ca
logger.info("Create a big file on the target node to reach critical disk utilization level")
disk_info = psutil.disk_usage(workdir)
with random_content_file(workdir, int(disk_info.total*DISK_FILL_TARGET_RATIO) - disk_info.used):
with random_content_file(workdir, int(disk_info.total*0.85) - disk_info.used):
for _ in range(2):
mark, _ = await log.wait_for("database - Set critical disk utilization mode: true", from_mark=mark)
@@ -140,7 +135,7 @@ async def test_autotoogle_compaction(manager: ManagerClient, volumes_factory: Ca
logger.info("Create a big file on the target node to reach critical disk utilization level")
disk_info = psutil.disk_usage(workdir)
with random_content_file(workdir, int(disk_info.total*DISK_FILL_TARGET_RATIO) - disk_info.used):
with random_content_file(workdir, int(disk_info.total*0.85) - disk_info.used):
for _ in range(2):
mark, _ = await log.wait_for("compaction_manager - Drained", from_mark=mark)
@@ -200,7 +195,7 @@ async def test_critical_utilization_during_decommission(manager: ManagerClient,
logger.info("Create a big file on the target node to reach critical disk utilization level")
disk_info = psutil.disk_usage(workdir)
with random_content_file(workdir, int(disk_info.total*DISK_FILL_TARGET_RATIO) - disk_info.used):
with random_content_file(workdir, int(disk_info.total*0.85) - disk_info.used):
mark, _ = await log.wait_for("Reached the critical disk utilization level", from_mark=mark)
mark, _ = await log.wait_for("Refreshing table load stats", from_mark=mark)
mark, _ = await log.wait_for("Refreshed table load stats", from_mark=mark)
@@ -236,7 +231,7 @@ async def test_reject_split_compaction(manager: ManagerClient, volumes_factory:
logger.info("Create a big file on the target node to reach critical disk utilization level")
disk_info = psutil.disk_usage(workdir)
with random_content_file(workdir, int(disk_info.total*DISK_FILL_TARGET_RATIO) - disk_info.used):
with random_content_file(workdir, int(disk_info.total*0.85) - disk_info.used):
await log.wait_for(f"Split task .* for table {cf} .* stopped, reason: Compaction for {cf} was stopped due to: drain")
@@ -261,7 +256,7 @@ async def test_split_compaction_not_triggered(manager: ManagerClient, volumes_fa
logger.info("Create a big file on the target node to reach critical disk utilization level")
disk_info = psutil.disk_usage(workdir)
with random_content_file(workdir, int(disk_info.total*DISK_FILL_TARGET_RATIO) - disk_info.used):
with random_content_file(workdir, int(disk_info.total*0.85) - disk_info.used):
for _ in range(2):
s1_mark, _ = await s1_log.wait_for("compaction_manager - Drained", from_mark=s1_mark)
@@ -296,7 +291,7 @@ async def test_tablet_repair(manager: ManagerClient, volumes_factory: Callable)
logger.info("Create a big file on the target node to reach critical disk utilization level")
disk_info = psutil.disk_usage(workdir)
with random_content_file(workdir, int(disk_info.total*DISK_FILL_TARGET_RATIO) - disk_info.used):
with random_content_file(workdir, int(disk_info.total*0.85) - disk_info.used):
for _ in range(2):
mark, _ = await log.wait_for("repair - Drained", from_mark=mark)
@@ -372,7 +367,7 @@ async def test_autotoogle_reject_incoming_migrations(manager: ManagerClient, vol
mark = await log.mark()
disk_info = psutil.disk_usage(workdir)
with random_content_file(workdir, int(disk_info.total*DISK_FILL_TARGET_RATIO) - disk_info.used):
with random_content_file(workdir, int(disk_info.total*0.85) - disk_info.used):
for _ in range(2):
mark, _ = await log.wait_for("database - Set critical disk utilization mode: true", from_mark=mark)
@@ -427,7 +422,7 @@ async def test_node_restart_while_tablet_split(manager: ManagerClient, volumes_f
logger.info("Create a big file on the target node to reach critical disk utilization level")
disk_info = psutil.disk_usage(workdir)
with random_content_file(workdir, int(disk_info.total*DISK_FILL_TARGET_RATIO) - disk_info.used):
with random_content_file(workdir, int(disk_info.total*0.85) - disk_info.used):
for _ in range(2):
mark, _ = await log.wait_for("compaction_manager - Drained", from_mark=mark)
@@ -510,7 +505,7 @@ async def test_repair_failure_on_split_rejection(manager: ManagerClient, volumes
logger.info("Create a big file on the target node to reach critical disk utilization level")
disk_info = psutil.disk_usage(workdir)
with random_content_file(workdir, int(disk_info.total*DISK_FILL_TARGET_RATIO) - disk_info.used):
with random_content_file(workdir, int(disk_info.total*0.85) - disk_info.used):
for _ in range(2):
mark, _ = await log.wait_for("compaction_manager - Drained", from_mark=mark)
@@ -529,93 +524,3 @@ async def test_repair_failure_on_split_rejection(manager: ManagerClient, volumes
await repair_task
mark, _ = await log.wait_for(f"Detected tablet split for table {cf}", from_mark=mark)
# Since we create 20M volumes, we need to reduce the commitlog segment size
# otherwise we hit out of space.
global_cmdline_with_disabled_monitor = [
"--disk-space-monitor-normal-polling-interval-in-seconds", "1",
"--critical-disk-utilization-level", "1.0",
"--commitlog-segment-size-in-mb", "2",
"--schema-commitlog-segment-size-in-mb", "4",
"--tablet-load-stats-refresh-interval-in-seconds", "1",
]
@pytest.mark.asyncio
@pytest.mark.skip_mode(mode='release', reason='error injections are not supported in release mode')
async def test_sstables_incrementally_released_during_streaming(manager: ManagerClient, volumes_factory: Callable) -> None:
"""
Test that source node will not run out of space if major compaction rewrites the sstables being streamed.
Expects the file streaming and major will both release sstables incrementally, reducing chances of 2x
space amplification.
Scenario:
- Create a 2-node cluster with limited disk space.
- Create a table with 2 tablets, one in each node
- Write 20% of node capacity to each tablet.
- Start decommissioning one node.
- During streaming, create a large file on the source node to push it over 85%
- Run major expecting the file streaming released the sstables incrementally. Had it not, source node runs out of space.
- Unblock streaming
- Verify that the decommission operation succeeds.
"""
cmdline = [*global_cmdline_with_disabled_monitor,
"--logger-log-level", "load_balancer=debug",
"--logger-log-level", "debug_error_injection=debug"
]
# the coordinator needs more space, so creating a 40M volume for it.
async with space_limited_servers(manager, volumes_factory, ["40M", "20M"], cmdline=cmdline,
property_file=[{"dc": "dc1", "rack": "r1"}]*2) as servers:
cql, _ = await manager.get_ready_cql(servers)
workdir = await manager.server_get_workdir(servers[1].server_id)
log = await manager.server_open_log(servers[1].server_id)
async with new_test_keyspace(manager, f"WITH replication = {{'class': 'NetworkTopologyStrategy', 'dc1': ['{servers[1].rack}'] }}"
" AND tablets = {'initial': 2}") as ks:
await manager.disable_tablet_balancing()
# Needs 1mb fragments in order to stress incremental release in file streaming
extra_table_param = "WITH compaction = {'class' : 'IncrementalCompactionStrategy', 'sstable_size_in_mb' : '1'} and compression = {}"
async with new_test_table(manager, ks, "pk int PRIMARY KEY, t text", extra_table_param) as cf:
before_disk_info = psutil.disk_usage(workdir)
# About 4mb per tablet
await asyncio.gather(*[cql.run_async(query) for query in write_generator(cf, 8000)])
# split data into 1mb fragments
await manager.api.keyspace_flush(servers[1].ip_addr, ks)
await manager.api.keyspace_compaction(servers[1].ip_addr, ks)
after_disk_info = psutil.disk_usage(workdir)
percent_by_writes = after_disk_info.percent - before_disk_info.percent
logger.info(f"Percent taken by writes {percent_by_writes}")
# assert sstable data content account for more than 20% of node's storage.
assert percent_by_writes > 20
# We want to trap only migrations which happened during decommission
await manager.api.quiesce_topology(servers[0].ip_addr)
await manager.api.enable_injection(servers[1].ip_addr, "tablet_stream_files_end_wait", one_shot=True)
mark = await log.mark()
logger.info(f"Workdir {workdir}")
decomm_task = asyncio.create_task(manager.decommission_node(servers[1].server_id))
await manager.enable_tablet_balancing()
mark, _ = await log.wait_for("tablet_stream_files_end_wait: waiting", from_mark=mark)
disk_info = psutil.disk_usage(workdir)
with random_content_file(workdir, int(disk_info.total*0.85) - disk_info.used):
disk_info = psutil.disk_usage(workdir)
logger.info(f"Percent used before major {disk_info.percent}")
# Run major in order to try to reproduce 2x space amplification if files aren't released
# incrementally by streamer.
await manager.api.keyspace_compaction(servers[1].ip_addr, ks)
await asyncio.gather(*[cql.run_async(query) for query in write_generator(cf, 100)])
disk_info = psutil.disk_usage(workdir)
logger.info(f"Percent used after major {disk_info.percent}")
await manager.api.message_injection(servers[1].ip_addr, "tablet_stream_files_end_wait")
await decomm_task

View File

@@ -1044,16 +1044,7 @@ SEASTAR_TEST_CASE(vector_store_client_https_rewrite_ca_cert) {
std::filesystem::copy_file(
std::string(certs.ca_cert_file()), std::string(broken_cert.get_path().string()), std::filesystem::copy_options::overwrite_existing);
// Wait for the truststore to reload the updated cert on all shards before attempting ANN requests.
// This avoids a race where an ANN request initiates a TLS handshake using the old (broken) credentials
// while the reload is still in progress, which can cause a long hang due to TLS handshake timeout.
co_await env.vector_store_client().invoke_on_all([&](this auto, vector_store_client& vs) -> future<> {
BOOST_CHECK(co_await repeat_until([&]() -> future<bool> {
co_return vector_store_client_tester::truststore_reload_count(vs) >= 1;
}));
});
// Wait for the client to succeed with the reloaded CA cert
// Wait for the client to reload the CA cert and succeed
co_await env.vector_store_client().invoke_on_all([&](this auto, vector_store_client& vs) -> future<> {
auto schema = env.local_db().find_schema("ks", "idx");
auto as = abort_source_timeout();

View File

@@ -27,7 +27,6 @@ target_sources(utils
hashers.cc
histogram_metrics_helper.cc
http.cc
http_client_error_processing.cc
human_readable.cc
i_filter.cc
io-wrappers.cc

View File

@@ -41,11 +41,6 @@ extern logging::logger errinj_logger;
using error_injection_parameters = std::unordered_map<sstring, sstring>;
// Callback type for error injection events
// Called when an injection point is triggered
// Parameters: injection_name, injection_type ("sleep", "exception", "handler", "lambda")
using error_injection_event_callback = std::function<void(std::string_view, std::string_view)>;
// Wraps the argument to breakpoint injection (see the relevant inject() overload
// in class error_injection below). Parameters:
// timeout - the timeout after which the pause is aborted
@@ -333,21 +328,6 @@ private:
// Map enabled-injection-name -> is-one-shot
std::unordered_map<std::string_view, injection_data> _enabled;
// Event callbacks to notify when injections are triggered
std::vector<error_injection_event_callback> _event_callbacks;
// Notify all registered event callbacks
void notify_event(std::string_view injection_name, std::string_view injection_type) {
for (const auto& callback : _event_callbacks) {
try {
callback(injection_name, injection_type);
} catch (...) {
errinj_logger.warn("Error injection event callback failed for \"{}\": {}",
injection_name, std::current_exception());
}
}
}
bool is_one_shot(const std::string_view& injection_name) const {
const auto it = _enabled.find(injection_name);
if (it == _enabled.end()) {
@@ -417,17 +397,6 @@ public:
| std::ranges::to<std::vector<sstring>>();
}
// \brief Register an event callback to be notified when injections are triggered
// \param callback function to call when injection is triggered
void register_event_callback(error_injection_event_callback callback) {
_event_callbacks.push_back(std::move(callback));
}
// \brief Clear all registered event callbacks
void clear_event_callbacks() {
_event_callbacks.clear();
}
// \brief Inject a lambda call
// \param f lambda to be run
[[gnu::always_inline]]
@@ -435,8 +404,7 @@ public:
if (!enter(name)) {
return;
}
errinj_logger.info("Triggering injection \"{}\"", name);
notify_event(name, "lambda");
errinj_logger.debug("Triggering injection \"{}\"", name);
f();
}
@@ -446,8 +414,7 @@ public:
if (!enter(name)) {
return make_ready_future<>();
}
errinj_logger.info("Triggering sleep injection \"{}\" ({}ms)", name, duration.count());
notify_event(name, "sleep");
errinj_logger.debug("Triggering sleep injection \"{}\" ({}ms)", name, duration.count());
return seastar::sleep(duration);
}
@@ -457,8 +424,7 @@ public:
if (!enter(name)) {
return make_ready_future<>();
}
errinj_logger.info("Triggering abortable sleep injection \"{}\" ({}ms)", name, duration.count());
notify_event(name, "sleep");
errinj_logger.debug("Triggering abortable sleep injection \"{}\" ({}ms)", name, duration.count());
return seastar::sleep_abortable(duration, as);
}
@@ -472,8 +438,7 @@ public:
// Time left until deadline
auto duration = deadline - Clock::now();
errinj_logger.info("Triggering sleep injection \"{}\" ({})", name, duration);
notify_event(name, "sleep");
errinj_logger.debug("Triggering sleep injection \"{}\" ({})", name, duration);
return seastar::sleep<Clock>(duration);
}
@@ -488,8 +453,7 @@ public:
return make_ready_future<>();
}
errinj_logger.info("Triggering exception injection \"{}\"", name);
notify_event(name, "exception");
errinj_logger.debug("Triggering exception injection \"{}\"", name);
return make_exception_future<>(exception_factory());
}
@@ -509,8 +473,7 @@ public:
co_return;
}
errinj_logger.info("Triggering injection \"{}\" with injection handler", name);
notify_event(name, "handler");
errinj_logger.debug("Triggering injection \"{}\" with injection handler", name);
injection_handler handler(data->shared_data, share_messages);
data->handlers.push_back(handler);
@@ -616,22 +579,6 @@ public:
return errinj.enabled_injections();
}
// \brief Register an event callback on all shards
static future<> register_event_callback_on_all(error_injection_event_callback callback) {
return smp::invoke_on_all([callback = std::move(callback)] {
auto& errinj = _local;
errinj.register_event_callback(callback);
});
}
// \brief Clear all event callbacks on all shards
static future<> clear_event_callbacks_on_all() {
return smp::invoke_on_all([] {
auto& errinj = _local;
errinj.clear_event_callbacks();
});
}
static error_injection& get_local() {
return _local;
}
@@ -759,22 +706,6 @@ public:
[[gnu::always_inline]]
static std::vector<sstring> enabled_injections_on_all() { return {}; }
[[gnu::always_inline]]
void register_event_callback(error_injection_event_callback callback) {}
[[gnu::always_inline]]
void clear_event_callbacks() {}
[[gnu::always_inline]]
static future<> register_event_callback_on_all(error_injection_event_callback callback) {
return make_ready_future<>();
}
[[gnu::always_inline]]
static future<> clear_event_callbacks_on_all() {
return make_ready_future<>();
}
static error_injection& get_local() {
return _local;
}

View File

@@ -26,7 +26,6 @@
#include <seastar/core/align.hh>
#include <functional>
#include <optional>
#include <system_error>
#include <type_traits>
@@ -212,75 +211,3 @@ inline std::exception_ptr make_nested_exception_ptr(Ex&& ex, std::exception_ptr
}
#endif
}
namespace exception::internal {
template <typename F>
struct lambda_arg;
template <typename R, typename C, typename Arg>
struct lambda_arg<R (C::*)(Arg) const> {
using type = Arg;
};
template <typename F>
using lambda_arg_t = std::remove_cvref_t<typename lambda_arg<decltype(&F::operator())>::type>;
} // namespace exception::internal
// dispatch_exception: unwraps nested exceptions (if any) and applies handlers
// The dispatcher gets as input the exception_ptr to process, a default handler
// to call if no other handler matches, and a variadic list of TypedHandlers.
// All handlers (including the default one) must return the same type R.
template <typename R, typename DefaultHandler, typename... Handlers>
requires std::is_same_v<R, std::invoke_result_t<DefaultHandler, std::exception_ptr, std::string&&>> &&
(std::is_same_v<R, std::invoke_result_t<Handlers, const exception::internal::lambda_arg_t<Handlers>&>> && ...)
R dispatch_exception(std::exception_ptr eptr, DefaultHandler&& default_handler, Handlers&&... handlers) {
std::string original_message;
while (eptr) {
try {
std::rethrow_exception(eptr);
} catch (const std::exception& e) {
if (original_message.empty()) {
original_message = e.what();
}
std::optional<R> result;
(
[&] {
using F = std::decay_t<Handlers>;
using Arg = exception::internal::lambda_arg_t<F>;
if constexpr (std::is_base_of_v<std::exception, Arg>) {
if (!result) {
if (auto* casted = dynamic_cast<const Arg*>(&e)) {
result = handlers(*casted);
}
}
}
}(),
...);
if (result) {
return *result;
}
// Try to unwrap nested exception
try {
std::rethrow_if_nested(e);
} catch (...) {
eptr = std::current_exception();
continue;
}
return default_handler(eptr, std::move(original_message));
} catch (...) {
return default_handler(eptr, std::move(original_message));
}
}
return default_handler(eptr, std::move(original_message));
}

View File

@@ -429,7 +429,7 @@ future<> utils::gcp::storage::client::object_data_sink::acquire_session() {
}
auto path = fmt::format("/upload/storage/v1/b/{}/o?uploadType=resumable&name={}"
, _bucket
, seastar::http::internal::url_encode(_object_name)
, _object_name
);
auto reply = co_await _impl->send_with_retry(path
@@ -689,11 +689,7 @@ future<temporary_buffer<char>> utils::gcp::storage::client::object_data_source::
}
// Ensure we read from the same generation as we queried in read_info. Note: mock server ignores this.
auto path = fmt::format("/storage/v1/b/{}/o/{}?ifGenerationMatch={}&alt=media"
, _bucket
, seastar::http::internal::url_encode(_object_name)
, _generation
);
auto path = fmt::format("/storage/v1/b/{}/o/{}?ifGenerationMatch={}&alt=media", _bucket, _object_name, _generation);
auto range = fmt::format("bytes={}-{}", _position, _position+to_read-1); // inclusive range
co_await _impl->send_with_retry(path
@@ -803,7 +799,7 @@ future<temporary_buffer<char>> utils::gcp::storage::client::object_data_source::
future<> utils::gcp::storage::client::object_data_source::read_info() {
gcp_storage.debug("Read info {}:{}", _bucket, _object_name);
auto path = fmt::format("/storage/v1/b/{}/o/{}", _bucket, seastar::http::internal::url_encode(_object_name));
auto path = fmt::format("/storage/v1/b/{}/o/{}", _bucket, _object_name);
auto res = co_await _impl->send_with_retry(path
, GCP_OBJECT_SCOPE_READ_ONLY
@@ -920,12 +916,6 @@ static utils::gcp::storage::object_info create_info(const rjson::value& item) {
// point in it. Return chunked_vector to avoid large alloc, but keep it
// in one object... for now...
future<utils::chunked_vector<utils::gcp::storage::object_info>> utils::gcp::storage::client::list_objects(std::string_view bucket_in, std::string_view prefix, bucket_paging& pager) {
utils::chunked_vector<utils::gcp::storage::object_info> result;
if (pager.done) {
co_return result;
}
std::string bucket(bucket_in);
gcp_storage.debug("List bucket {} (prefix={}, max_results={})", bucket, prefix, pager.max_results);
@@ -945,6 +935,8 @@ future<utils::chunked_vector<utils::gcp::storage::object_info>> utils::gcp::stor
psep = "&&";
}
utils::chunked_vector<utils::gcp::storage::object_info> result;
co_await _impl->send_with_retry(path
, GCP_OBJECT_SCOPE_READ_ONLY
, ""s
@@ -973,7 +965,6 @@ future<utils::chunked_vector<utils::gcp::storage::object_info>> utils::gcp::stor
}
pager.token = rjson::get_opt<std::string>(root, "nextPageToken").value_or(""s);
pager.done = pager.token.empty();
for (auto& item : items->GetArray()) {
object_info info = create_info(item);
@@ -998,7 +989,7 @@ future<> utils::gcp::storage::client::delete_object(std::string_view bucket_in,
gcp_storage.debug("Delete object {}:{}", bucket, object_name);
auto path = fmt::format("/storage/v1/b/{}/o/{}", bucket, seastar::http::internal::url_encode(object_name));
auto path = fmt::format("/storage/v1/b/{}/o/{}", bucket, object_name);
auto res = co_await _impl->send_with_retry(path
, GCP_OBJECT_SCOPE_READ_WRITE
@@ -1035,11 +1026,7 @@ future<> utils::gcp::storage::client::rename_object(std::string_view bucket_in,
gcp_storage.debug("Move object {}:{} -> {}", bucket, object_name, new_name);
auto path = fmt::format("/storage/v1/b/{}/o/{}/moveTo/o/{}"
, bucket
, seastar::http::internal::url_encode(object_name)
, seastar::http::internal::url_encode(new_name)
);
auto path = fmt::format("/storage/v1/b/{}/o/{}/moveTo/o/{}", bucket, object_name, new_name);
auto res = co_await _impl->send_with_retry(path
, GCP_OBJECT_SCOPE_READ_WRITE
, ""s
@@ -1065,12 +1052,7 @@ future<> utils::gcp::storage::client::rename_object(std::string_view bucket_in,
future<> utils::gcp::storage::client::copy_object(std::string_view bucket_in, std::string_view object_name_in, std::string_view new_bucket_in, std::string_view to_name_in) {
std::string bucket(bucket_in), object_name(object_name_in), new_bucket(new_bucket_in), to_name(to_name_in);
auto path = fmt::format("/storage/v1/b/{}/o/{}/rewriteTo/b/{}/o/{}"
, bucket
, seastar::http::internal::url_encode(object_name)
, new_bucket
, seastar::http::internal::url_encode(to_name)
);
auto path = fmt::format("/storage/v1/b/{}/o/{}/rewriteTo/b/{}/o/{}", bucket, object_name, new_bucket, to_name);
std::string body;
for (;;) {
@@ -1123,7 +1105,7 @@ future<utils::gcp::storage::object_info> utils::gcp::storage::client::merge_obje
std::string bucket(bucket_in), object_name(dest_object_name);
auto path = fmt::format("/storage/v1/b/{}/o/{}/compose", bucket, seastar::http::internal::url_encode(object_name));
auto path = fmt::format("/storage/v1/b/{}/o/{}/compose", bucket, object_name);
auto body = rjson::print(compose);
auto res = co_await _impl->send_with_retry(path

View File

@@ -49,12 +49,10 @@ namespace utils::gcp::storage {
private:
uint32_t max_results;
std::string token;
bool done;
friend class client;
public:
bucket_paging(uint64_t max = 1000)
: max_results(max)
, done(false)
{}
bucket_paging(const bucket_paging&) = delete;
bucket_paging(bucket_paging&&) = default;

View File

@@ -1,66 +0,0 @@
/*
* Copyright (C) 2026-present ScyllaDB
*/
/*
* SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
*/
#include "http_client_error_processing.hh"
#include <seastar/http/exception.hh>
#include <gnutls/gnutls.h>
namespace utils::http {
retryable from_http_code(seastar::http::reply::status_type http_code) {
switch (http_code) {
case seastar::http::reply::status_type::unauthorized:
case seastar::http::reply::status_type::forbidden:
case seastar::http::reply::status_type::not_found:
return retryable::no;
case seastar::http::reply::status_type::too_many_requests:
case seastar::http::reply::status_type::internal_server_error:
case seastar::http::reply::status_type::bandwidth_limit_exceeded:
case seastar::http::reply::status_type::service_unavailable:
case seastar::http::reply::status_type::request_timeout:
case seastar::http::reply::status_type::page_expired:
case seastar::http::reply::status_type::login_timeout:
case seastar::http::reply::status_type::gateway_timeout:
case seastar::http::reply::status_type::network_connect_timeout:
case seastar::http::reply::status_type::network_read_timeout:
return retryable::yes;
default:
return retryable{seastar::http::reply::classify_status(http_code) == seastar::http::reply::status_class::server_error};
}
}
retryable from_system_error(const std::system_error& system_error) {
switch (system_error.code().value()) {
case static_cast<int>(std::errc::interrupted):
case static_cast<int>(std::errc::resource_unavailable_try_again):
case static_cast<int>(std::errc::timed_out):
case static_cast<int>(std::errc::connection_aborted):
case static_cast<int>(std::errc::connection_reset):
case static_cast<int>(std::errc::connection_refused):
case static_cast<int>(std::errc::broken_pipe):
case static_cast<int>(std::errc::network_unreachable):
case static_cast<int>(std::errc::host_unreachable):
case static_cast<int>(std::errc::network_down):
case static_cast<int>(std::errc::network_reset):
case static_cast<int>(std::errc::no_buffer_space):
// GNU TLS section. Since we pack gnutls error codes in std::system_error and rethrow it as std::nested_exception we have to handle them here.
case GNUTLS_E_PREMATURE_TERMINATION:
case GNUTLS_E_AGAIN:
case GNUTLS_E_INTERRUPTED:
case GNUTLS_E_PUSH_ERROR:
case GNUTLS_E_PULL_ERROR:
case GNUTLS_E_TIMEDOUT:
case GNUTLS_E_SESSION_EOF:
case GNUTLS_E_BAD_COOKIE: // as per RFC6347 section-4.2.1 client should retry
return retryable::yes;
default:
return retryable::no;
}
}
} // namespace utils::http

View File

@@ -1,20 +0,0 @@
/*
* Copyright (C) 2026-present ScyllaDB
*/
/*
* SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
*/
#pragma once
#include <seastar/http/reply.hh>
#include <seastar/util/bool_class.hh>
namespace utils::http {
using retryable = seastar::bool_class<struct is_retryable>;
retryable from_http_code(seastar::http::reply::status_type http_code);
retryable from_system_error(const std::system_error& system_error);
} // namespace utils::http

View File

@@ -13,15 +13,13 @@
#endif
#include "aws_error.hh"
#include "utils/exceptions.hh"
#include <seastar/util/log.hh>
#include <seastar/http/exception.hh>
#include <gnutls/gnutls.h>
#include <memory>
namespace aws {
using namespace utils::http;
aws_error::aws_error(aws_error_type error_type, retryable is_retryable) : _type(error_type), _is_retryable(is_retryable) {
}
@@ -132,32 +130,64 @@ aws_error aws_error::from_http_code(seastar::http::reply::status_type http_code)
}
aws_error aws_error::from_system_error(const std::system_error& system_error) {
auto is_retryable = utils::http::from_system_error(system_error);
if (is_retryable == retryable::yes) {
return {aws_error_type::NETWORK_CONNECTION, system_error.code().message(), is_retryable};
switch (system_error.code().value()) {
case static_cast<int>(std::errc::interrupted):
case static_cast<int>(std::errc::resource_unavailable_try_again):
case static_cast<int>(std::errc::timed_out):
case static_cast<int>(std::errc::connection_aborted):
case static_cast<int>(std::errc::connection_reset):
case static_cast<int>(std::errc::connection_refused):
case static_cast<int>(std::errc::broken_pipe):
case static_cast<int>(std::errc::network_unreachable):
case static_cast<int>(std::errc::host_unreachable):
case static_cast<int>(std::errc::network_down):
case static_cast<int>(std::errc::network_reset):
case static_cast<int>(std::errc::no_buffer_space):
// GNU TLS section. Since we pack gnutls error codes in std::system_error and rethrow it as std::nested_exception we have to handle them here.
case GNUTLS_E_PREMATURE_TERMINATION:
case GNUTLS_E_AGAIN:
case GNUTLS_E_INTERRUPTED:
case GNUTLS_E_PUSH_ERROR:
case GNUTLS_E_PULL_ERROR:
case GNUTLS_E_TIMEDOUT:
case GNUTLS_E_SESSION_EOF:
case GNUTLS_E_BAD_COOKIE: // as per RFC6347 section-4.2.1 client should retry
return {aws_error_type::NETWORK_CONNECTION, system_error.code().message(), retryable::yes};
default:
return {aws_error_type::UNKNOWN,
format("Non-retryable system error occurred. Message: {}, code: {}", system_error.code().message(), system_error.code().value()),
retryable::no};
}
return {aws_error_type::UNKNOWN,
format("Non-retryable system error occurred. Message: {}, code: {}", system_error.code().message(), system_error.code().value()),
is_retryable};
}
aws_error aws_error::from_exception_ptr(std::exception_ptr exception) {
return dispatch_exception<aws_error>(
std::move(exception),
[](std::exception_ptr eptr, std::string&& original_message) {
if (!original_message.empty()) {
return aws_error{aws_error_type::UNKNOWN, std::move(original_message), retryable::no};
std::string original_message;
while (exception) {
try {
std::rethrow_exception(exception);
} catch (const aws_exception& ex) {
return ex.error();
} catch (const seastar::httpd::unexpected_status_error& ex) {
return from_http_code(ex.status());
} catch (const std::system_error& ex) {
return from_system_error(ex);
} catch (const std::exception& ex) {
if (original_message.empty()) {
original_message = ex.what();
}
if (!eptr) {
return aws_error{aws_error_type::UNKNOWN, "No exception was provided to `aws_error::from_exception_ptr` function call", retryable::no};
try {
std::rethrow_if_nested(ex);
} catch (...) {
exception = std::current_exception();
continue;
}
return aws_error{
aws_error_type::UNKNOWN, seastar::format("No error message was provided, exception content: {}", eptr), retryable::no};
},
[](const aws_exception& ex) { return ex.error(); },
[](const seastar::httpd::unexpected_status_error& ex) { return from_http_code(ex.status()); },
[](const std::system_error& ex) { return from_system_error(ex); });
return aws_error{aws_error_type::UNKNOWN, std::move(original_message), retryable::no};
} catch (...) {
return aws_error{aws_error_type::UNKNOWN, seastar::format("No error message was provided, exception content: {}", std::current_exception()), retryable::no};
}
}
return aws_error{aws_error_type::UNKNOWN, "No exception was provided to `aws_error::from_exception_ptr` function call", retryable::no};
}
const aws_errors& aws_error::get_errors() {

View File

@@ -14,7 +14,6 @@
#include <string>
#include <string_view>
#include <unordered_map>
#include "utils/http_client_error_processing.hh"
namespace aws {
@@ -89,20 +88,21 @@ enum class aws_error_type : uint8_t {
};
class aws_error;
using retryable = seastar::bool_class<struct is_retryable>;
using aws_errors = std::unordered_map<std::string_view, const aws_error>;
class aws_error {
aws_error_type _type{aws_error_type::OK};
std::string _message;
utils::http::retryable _is_retryable{utils::http::retryable::no};
retryable _is_retryable{retryable::no};
public:
aws_error() = default;
aws_error(aws_error_type error_type, utils::http::retryable is_retryable);
aws_error(aws_error_type error_type, std::string&& error_message, utils::http::retryable is_retryable);
aws_error(aws_error_type error_type, retryable is_retryable);
aws_error(aws_error_type error_type, std::string&& error_message, retryable is_retryable);
[[nodiscard]] const std::string& get_error_message() const { return _message; }
[[nodiscard]] aws_error_type get_error_type() const { return _type; }
[[nodiscard]] utils::http::retryable is_retryable() const { return _is_retryable; }
[[nodiscard]] retryable is_retryable() const { return _is_retryable; }
static std::optional<aws_error> parse(seastar::sstring&& body);
static aws_error from_http_code(seastar::http::reply::status_type http_code);
static aws_error from_system_error(const std::system_error& system_error);

View File

@@ -80,13 +80,9 @@ static logging::logger s3l("s3");
// "Each part must be at least 5 MB in size, except the last part."
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_UploadPart.html
static constexpr size_t aws_minimum_part_size = 5_MiB;
// https://docs.aws.amazon.com/AmazonS3/latest/userguide/qfacts.html
static constexpr size_t aws_maximum_part_size = 5_GiB;
// "Part numbers can be any number from 1 to 10,000, inclusive."
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_UploadPart.html
static constexpr unsigned aws_maximum_parts_in_piece = 10'000;
// https://docs.aws.amazon.com/AmazonS3/latest/userguide/UsingObjects.html
static constexpr size_t aws_maximum_object_size = aws_maximum_parts_in_piece * aws_maximum_part_size;
future<> ignore_reply(const http::reply& rep, input_stream<char>&& in_) {
auto in = std::move(in_);
@@ -338,13 +334,13 @@ http::experimental::client::reply_handler client::wrap_handler(http::request& re
s3l.warn("Request failed with REQUEST_TIME_TOO_SKEWED. Machine time: {}, request timestamp: {}",
utils::aws::format_time_point(db_clock::now()),
request.get_header("x-amz-date"));
should_retry = utils::http::retryable::yes;
should_retry = aws::retryable::yes;
co_await authorize(request);
}
if (possible_error->get_error_type() == aws::aws_error_type::EXPIRED_TOKEN) {
s3l.warn("Request failed with EXPIRED_TOKEN. Resetting credentials");
_credentials = {};
should_retry = utils::http::retryable::yes;
should_retry = aws::retryable::yes;
co_await authorize(request);
}
co_await coroutine::return_exception_ptr(std::make_exception_ptr(
@@ -359,7 +355,7 @@ http::experimental::client::reply_handler client::wrap_handler(http::request& re
// We need to be able to simulate a retry in s3 tests
if (utils::get_local_injector().enter("s3_client_fail_authorization")) {
throw aws::aws_exception(
aws::aws_error{aws::aws_error_type::HTTP_UNAUTHORIZED, "EACCESS fault injected to simulate authorization failure", utils::http::retryable::no});
aws::aws_error{aws::aws_error_type::HTTP_UNAUTHORIZED, "EACCESS fault injected to simulate authorization failure", aws::retryable::no});
}
co_return co_await handler(rep, std::move(_in));
} catch (...) {
@@ -663,8 +659,6 @@ sstring parse_multipart_copy_upload_etag(sstring& body) {
class client::multipart_upload {
protected:
static constexpr size_t _max_multipart_concurrency = 16;
shared_ptr<client> _client;
sstring _object_name;
sstring _upload_id;
@@ -734,15 +728,10 @@ private:
std::exception_ptr ex;
try {
auto parts = std::views::iota(size_t{0}, (source_size + part_size - 1) / part_size);
_part_etags.resize(parts.size());
co_await max_concurrent_for_each(parts,
_max_multipart_concurrency,
[part_size, source_size, this](auto part_num) -> future<> {
auto part_offset = part_num * part_size;
auto actual_part_size = std::min(source_size - part_offset, part_size);
co_await copy_part(part_offset, actual_part_size, part_num);
});
for (size_t offset = 0; offset < source_size; offset += part_size) {
part_size = std::min(source_size - offset, part_size);
co_await copy_part(offset, part_size);
}
// Here we are going to finalize the upload and close the _bg_flushes, in case an exception is thrown the
// gate will be closed and the upload will be aborted. See below.
co_await finalize_upload();
@@ -759,7 +748,9 @@ private:
}
}
future<> copy_part(size_t offset, size_t part_size, size_t part_number) {
future<> copy_part(size_t offset, size_t part_size) {
unsigned part_number = _part_etags.size();
_part_etags.emplace_back();
auto req = http::request::make("PUT", _client->_host, _object_name);
req._headers["x-amz-copy-source"] = _source_object;
auto range = format("bytes={}-{}", offset, offset + part_size - 1);
@@ -769,7 +760,11 @@ private:
req.set_query_param("partNumber", to_sstring(part_number + 1));
req.set_query_param("uploadId", _upload_id);
co_await _client->make_request(std::move(req),[this, part_number, start = s3_clock::now()](group_client& gc, const http::reply& reply, input_stream<char>&& in) -> future<> {
// upload the parts in the background for better throughput
auto gh = _bg_flushes.hold();
// Ignoring the result of make_request() because we don't want to block and it is safe since we have a gate we are going to wait on and all argument are
// captured by value or moved into the fiber
std::ignore = _client->make_request(std::move(req),[this, part_number, start = s3_clock::now()](group_client& gc, const http::reply& reply, input_stream<char>&& in) -> future<> {
auto _in = std::move(in);
auto body = co_await util::read_entire_stream_contiguous(_in);
auto etag = parse_multipart_copy_upload_etag(body);
@@ -781,7 +776,8 @@ private:
},http::reply::status_type::ok, _as)
.handle_exception([this, part_number](auto ex) {
s3l.warn("Failed to upload part {}, upload id {}. Reason: {}", part_number, _upload_id, ex);
});
})
.finally([gh = std::move(gh)] {});
co_return;
}
@@ -1289,7 +1285,7 @@ class client::chunked_download_source final : public seastar::data_source_impl {
while (_buffers_size < _max_buffers_size && !_is_finished) {
utils::get_local_injector().inject("kill_s3_inflight_req", [] {
// Inject non-retryable error to emulate source failure
throw aws::aws_exception(aws::aws_error(aws::aws_error_type::RESOURCE_NOT_FOUND, "Injected ResourceNotFound", utils::http::retryable::no));
throw aws::aws_exception(aws::aws_error(aws::aws_error_type::RESOURCE_NOT_FOUND, "Injected ResourceNotFound", aws::retryable::no));
});
s3l.trace("Fiber for object '{}' will try to read within range {}", _object_name, _range);
@@ -1533,11 +1529,13 @@ class client::do_upload_file : private multipart_upload {
}
}
future<> upload_part(file f, uint64_t offset, uint64_t part_size, uint64_t part_number) {
future<> upload_part(file f, uint64_t offset, uint64_t part_size) {
// upload a part in a multipart upload, see
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_UploadPart.html
auto mem_units = co_await _client->claim_memory(_transmit_size, _as);
unsigned part_number = _part_etags.size();
_part_etags.emplace_back();
auto req = http::request::make("PUT", _client->_host, _object_name);
req._headers["Content-Length"] = to_sstring(part_size);
req.set_query_param("partNumber", to_sstring(part_number + 1));
@@ -1548,7 +1546,9 @@ class client::do_upload_file : private multipart_upload {
auto output = std::move(out_);
return copy_to(std::move(input), std::move(output), _transmit_size, progress);
});
co_await _client->make_request(std::move(req), [this, part_size, part_number, start = s3_clock::now()] (group_client& gc, const http::reply& reply, input_stream<char>&& in_) mutable -> future<> {
// upload the parts in the background for better throughput
auto gh = _bg_flushes.hold();
std::ignore = _client->make_request(std::move(req), [this, part_size, part_number, start = s3_clock::now()] (group_client& gc, const http::reply& reply, input_stream<char>&& in_) mutable -> future<> {
auto etag = reply.get_header("ETag");
s3l.trace("uploaded {} part data -> etag = {} (upload id {})", part_number, etag, _upload_id);
_part_etags[part_number] = std::move(etag);
@@ -1556,7 +1556,32 @@ class client::do_upload_file : private multipart_upload {
return make_ready_future();
}, http::reply::status_type::ok, _as).handle_exception([this, part_number] (auto ex) {
s3l.warn("couldn't upload part {}: {} (upload id {})", part_number, ex, _upload_id);
});
}).finally([gh = std::move(gh)] {});
}
// returns pair<num_of_parts, part_size>
static std::pair<unsigned, size_t> calc_part_size(size_t total_size, size_t part_size) {
if (part_size > 0) {
if (part_size < aws_minimum_part_size) {
on_internal_error(s3l, fmt::format("part_size too large: {} < {}", part_size, aws_minimum_part_size));
}
const size_t num_parts = div_ceil(total_size, part_size);
if (num_parts > aws_maximum_parts_in_piece) {
on_internal_error(s3l, fmt::format("too many parts: {} > {}", num_parts, aws_maximum_parts_in_piece));
}
return {num_parts, part_size};
}
// if part_size is 0, this means the caller leaves it to us to decide
// the part_size. to be more reliance, say, we don't have to re-upload
// a giant chunk of buffer if a certain part fails to upload, we prefer
// small parts, let's make it a multiple of MiB.
part_size = div_ceil(total_size / aws_maximum_parts_in_piece, 1_MiB);
// The default part size for multipart upload is set to 50MiB.
// This value was determined empirically by running `perf_s3_client` with various part sizes to find the optimal one.
static constexpr size_t default_part_size = 50_MiB;
part_size = std::max(part_size, default_part_size);
return {div_ceil(total_size, part_size), part_size};
}
future<> multi_part_upload(file&& f, uint64_t total_size, size_t part_size) {
@@ -1564,14 +1589,12 @@ class client::do_upload_file : private multipart_upload {
std::exception_ptr ex;
try {
co_await max_concurrent_for_each(std::views::iota(size_t{0}, (total_size + part_size - 1) / part_size),
_max_multipart_concurrency,
[part_size, total_size, this, f = file{f}](auto part_num) -> future<> {
auto part_offset = part_num * part_size;
auto actual_part_size = std::min(total_size - part_offset, part_size);
s3l.trace("upload_part: {}~{}/{}", part_offset, actual_part_size, total_size);
co_await upload_part(f, part_offset, actual_part_size, part_num);
});
for (size_t offset = 0; offset < total_size; offset += part_size) {
part_size = std::min(total_size - offset, part_size);
s3l.trace("upload_part: {}~{}/{}", offset, part_size, total_size);
co_await upload_part(file{f}, offset, part_size);
}
co_await finalize_upload();
} catch (...) {
ex = std::current_exception();
@@ -1629,7 +1652,7 @@ public:
// parallel to improve throughput
if (file_size > aws_minimum_part_size) {
auto [num_parts, part_size] = calc_part_size(file_size, _part_size);
_part_etags.resize(num_parts);
_part_etags.reserve(num_parts);
co_await multi_part_upload(std::move(f), file_size, part_size);
} else {
// single part upload
@@ -1926,34 +1949,4 @@ future<> client::bucket_lister::close() noexcept {
}
}
// returns pair<num_of_parts, part_size>
std::pair<unsigned, size_t> calc_part_size(size_t total_size, size_t part_size) {
if (total_size > aws_maximum_object_size) {
on_internal_error(s3l, fmt::format("object size too large: {} is larger than maximum S3 object size: {}", total_size, aws_maximum_object_size));
}
if (part_size > 0) {
if (part_size > aws_maximum_part_size) {
on_internal_error(s3l, fmt::format("part_size too large: {} is larger than maximum part size: {}", part_size, aws_maximum_part_size));
}
if (part_size < aws_minimum_part_size) {
on_internal_error(s3l, fmt::format("part_size too small: {} is smaller than minimum part size: {}", part_size, aws_minimum_part_size));
}
const size_t num_parts = div_ceil(total_size, part_size);
if (num_parts > aws_maximum_parts_in_piece) {
on_internal_error(s3l, fmt::format("too many parts: {} > {}", num_parts, aws_maximum_parts_in_piece));
}
return {num_parts, part_size};
}
// if part_size is 0, this means the caller leaves it to us to decide the part_size. The default part size for multipart upload is set to 50MiB. This
// value was determined empirically by running `perf_s3_client` with various part sizes to find the optimal one.
static constexpr size_t default_part_size = 50_MiB;
const size_t num_parts = div_ceil(total_size, default_part_size);
if (num_parts <= aws_maximum_parts_in_piece) {
return {num_parts, default_part_size};
}
part_size = align_up(div_ceil(total_size, aws_maximum_parts_in_piece), 1_MiB);
return {div_ceil(total_size, part_size), part_size};
}
} // s3 namespace

View File

@@ -251,8 +251,6 @@ public:
future<> close();
};
std::pair<unsigned, size_t> calc_part_size(size_t total_size, size_t part_size);
} // namespace s3
template <>

View File

@@ -39,7 +39,7 @@ seastar::future<bool> default_aws_retry_strategy::should_retry(std::exception_pt
co_return false;
}
auto err = aws_error::from_exception_ptr(error);
bool should_retry = err.is_retryable() == utils::http::retryable::yes;
bool should_retry = err.is_retryable() == retryable::yes;
if (should_retry) {
rs_logger.debug("AWS HTTP client request failed. Reason: {}. Retry# {}", err.get_error_message(), attempted_retries);
co_await sleep_before_retry(attempted_retries);

View File

@@ -55,13 +55,9 @@ private:
future<connected_socket> connect() {
auto addr = socket_address(_endpoint.ip, _endpoint.port);
if (_creds) {
auto socket = co_await tls::connect(_creds, addr, tls::tls_options{.server_name = _endpoint.host});
// tls::connect() only performs the TCP handshake — the TLS handshake is deferred until the first I/O operation.
// Force the TLS handshake to happen here so that the connection timeout applies to it.
co_await tls::check_session_is_resumed(socket);
co_return socket;
return tls::connect(_creds, addr, tls::tls_options{.server_name = _endpoint.host});
}
co_return co_await seastar::connect(addr, {}, transport::TCP);
return seastar::connect(addr, {}, transport::TCP);
}
std::chrono::milliseconds timeout() const {

View File

@@ -32,10 +32,8 @@ seastar::future<seastar::shared_ptr<seastar::tls::certificate_credentials>> trus
if (self._credentials) {
b.rebuild(*self._credentials);
}
self._reload_count++;
return make_ready_future();
});
_reload_count++;
}
});
} else {

View File

@@ -29,10 +29,6 @@ public:
seastar::future<seastar::shared_ptr<seastar::tls::certificate_credentials>> get();
seastar::future<> stop();
unsigned reload_count() const {
return _reload_count;
}
private:
seastar::future<seastar::tls::credentials_builder> create_builder() const;
@@ -41,7 +37,6 @@ private:
seastar::shared_ptr<seastar::tls::certificate_credentials> _credentials;
invoke_on_others_type _invoke_on_others;
seastar::gate _gate;
unsigned _reload_count = 0;
};
} // namespace vector_search

View File

@@ -414,8 +414,4 @@ auto vector_store_client_tester::resolve_hostname(vector_store_client& vsc, abor
co_return ret;
}
unsigned vector_store_client_tester::truststore_reload_count(vector_store_client& vsc) {
return vsc._impl->_truststore.reload_count();
}
} // namespace vector_search

View File

@@ -89,7 +89,6 @@ struct vector_store_client_tester {
static void set_dns_resolver(vector_store_client& vsc, std::function<future<std::vector<net::inet_address>>(sstring const&)> resolver);
static void trigger_dns_resolver(vector_store_client& vsc);
static auto resolve_hostname(vector_store_client& vsc, abort_source& as) -> future<std::vector<net::inet_address>>;
static unsigned truststore_reload_count(vector_store_client& vsc);
};
} // namespace vector_search