Files
scylladb/test/cluster/test_encryption.py
Botond Dénes 9770a4c081 test/cluster/test_encryption.py: use single-partition reads in read_verify_workload()
Replace the range scan in read_verify_workload() with individual
single-partition queries, using the keys returned by
prepare_write_workload() instead of hard-coding them.

The range scan was previously observed to time out in debug mode after
a hard cluster restart. Single-partition reads are lighter on the
cluster and less likely to time out under load.

The new verification is also stricter: instead of merely checking that
the expected number of rows is returned, it verifies that each written
key is individually readable, catching any data-loss or key-identity
mismatch that the old count-only check would have missed.

This is the second attempt at stabilizing this test, after the recent
854c374ebf. That fix made sure that the
cluster has converged on topology and nodes see each other before running
the verify workload.

Fixes: SCYLLADB-1331

Closes scylladb/scylladb#29313
2026-04-12 00:38:20 +03:00

553 lines
26 KiB
Python

#
# Copyright (C) 2024-present ScyllaDB
#
# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
#
import asyncio
import contextlib
import tempfile
import time
import glob
import os
import itertools
import logging
import subprocess
import json
import uuid
from test.pylib.manager_client import ManagerClient, ServerInfo
from test.pylib.util import wait_for_cql_and_get_hosts
from test.pylib.tablets import get_all_tablet_replicas
from test.cqlpy import nodetool
from test.pylib.encryption_provider import KeyProviderFactory, KeyProvider, make_key_provider_factory, KMSKeyProviderFactory, LocalFileSystemKeyProviderFactory
from test.cluster.util import new_test_keyspace, new_test_table
from test.cluster.dtest.tools.assertions import assert_one
from typing import Callable, Coroutine
from cassandra import ConsistencyLevel
from cassandra.cluster import Session as CassandraSession, NoHostAvailable
from cassandra.protocol import ConfigurationException
from cassandra.auth import PlainTextAuthProvider
import pytest
# Module-level logger shared by the tests and helpers in this file.
logger = logging.getLogger(__name__)
@pytest.fixture(scope="function")
def workdir():
    """Yield the path of a fresh temporary directory that is removed after the test."""
    scratch = tempfile.TemporaryDirectory()
    try:
        yield scratch.name
    finally:
        scratch.cleanup()
async def test_file_streaming_respects_encryption(manager: ManagerClient, workdir):
    """Move the tablet of an encrypted table to a new node and read the row back.

    A one-node cluster (tablets enabled, balancing disabled) gets a table
    encrypted with a local key file under *workdir*. After a row is written,
    a second node joins, the table's tablet is moved onto it, and the row
    must still be readable.
    """
    server_config = {
        'tablets_mode_for_new_keyspaces': 'enabled',
    }
    server_cmdline = ['--smp=1']
    servers = [await manager.server_add(config=server_config, cmdline=server_cmdline)]
    await manager.disable_tablet_balancing()
    cql = manager.cql
    await wait_for_cql_and_get_hosts(cql, servers, time.time() + 60)

    cql.execute("CREATE KEYSPACE ks WITH REPLICATION = {'class' : 'NetworkTopologyStrategy', 'replication_factor': 1} AND tablets = {'initial': 1};")
    cql.execute(f"""CREATE TABLE ks.t(pk text primary key) WITH scylla_encryption_options = {{
'cipher_algorithm' : 'AES/ECB/PKCS5Padding',
'secret_key_strength' : 128,
'key_provider': 'LocalFileSystemKeyProviderFactory',
'secret_key_file': '{workdir}/data_encryption_key'
}}""")
    cql.execute("INSERT INTO ks.t(pk) VALUES('alamakota')")

    servers.append(await manager.server_add(config=server_config, cmdline=server_cmdline))
    tablet_replicas = await get_all_tablet_replicas(manager, servers[0], 'ks', 't')
    host_ids = await asyncio.gather(*(manager.get_host_id(srv.server_id) for srv in servers))
    source_host, destination_host = host_ids[0], host_ids[1]
    # NOTE(review): the two 0 arguments look like source/destination shards and the
    # last one identifies the tablet; confirm against manager.api.move_tablet.
    tablet = tablet_replicas[0][0]
    await manager.api.move_tablet(servers[0].ip_addr, "ks", "t",
                                  source_host, 0, destination_host, 0, tablet)

    rows = cql.execute("SELECT * from ks.t WHERE pk = 'alamakota'")
    assert len(list(rows)) == 1
def filter_ciphers(kp: "KeyProviderFactory", ciphers: dict[str, list[int]] = None) -> list[tuple[str, int]]:
    """Expand a cipher table into the (algorithm, key length) pairs *kp* supports.

    Args:
        kp: key provider factory. Only its ``supported_cipher(cipher, length)``
            predicate is used.
        ciphers: mapping of cipher algorithm name to candidate key lengths.
            ``None`` or an empty mapping means "no explicit cipher".

    Returns:
        ``[(None, None)]`` when *ciphers* is empty or ``None``, so callers
        iterate exactly once without explicit cipher options. Otherwise,
        every ``(algorithm, length)`` pair accepted by *kp*, in table order.
        Algorithms with an empty length list contribute nothing.
    """
    if not ciphers:
        return [(None, None)]
    return [(cipher, length)
            for cipher, lengths in ciphers.items()
            for length in lengths
            if kp.supported_cipher(cipher, length)]
async def create_ks(manager: ManagerClient, replication_factor: int=1):
    """Return the ``new_test_keyspace`` context manager for a NetworkTopologyStrategy keyspace.

    Callers use it as ``async with await create_ks(...) as ks``.
    *replication_factor* is put into the keyspace's replication map.
    """
    replication = ("{'class': 'NetworkTopologyStrategy', "
                   f"'replication_factor': {replication_factor}}}")
    return new_test_keyspace(manager, opts=f"with replication = {replication}")
async def create_encrypted_cf(manager: ManagerClient, ks: str,
                              columns: str=None,
                              cipher_algorithm=None,
                              secret_key_strength=None,
                              compression=None,
                              additional_options=None,
                              ):
    """Return a ``new_test_table`` context manager for a table with encryption options.

    The ``scylla_encryption_options`` map starts from *additional_options*.
    *cipher_algorithm* and *secret_key_strength* are added (or overridden)
    when truthy. When *compression* is given, ``<compression>Compressor`` is
    configured as the sstable compressor. *columns* defaults to
    ``key text PRIMARY KEY, c1 text, c2 text``.
    """
    table_columns = "key text PRIMARY KEY, c1 text, c2 text" if columns is None else columns

    encryption_options = dict(additional_options or {})
    if cipher_algorithm:
        encryption_options["cipher_algorithm"] = cipher_algorithm
    if secret_key_strength:
        encryption_options["secret_key_strength"] = secret_key_strength

    # The Python dict repr is used directly as the CQL map literal.
    clauses = [f'WITH scylla_encryption_options={encryption_options}']
    if compression is not None:
        clauses.append(f"compression = {{ 'sstable_compression': '{compression}Compressor' }}")
    return new_test_table(manager, ks, table_columns, " AND ".join(clauses))
async def prepare_write_workload(cql: CassandraSession, table_name, flush=True, n: int = None) -> list[str]:
    """Write *n* rows (default 100) into *table_name* and return their partition keys.

    Row ``i`` is ``(key, c1, c2) = ("k<i>", "value1", "value2")`` for
    ``i in range(n)``. All inserts are issued concurrently with consistency
    level ALL. When *flush* is true, ``nodetool.flush`` is called on the
    table afterwards.

    An explicit ``n=0`` writes nothing. Only ``None`` selects the default
    of 100.
    """
    row_count = 100 if n is None else n
    keys = [f"k{x}" for x in range(row_count)]
    c1_values = ['value1']
    c2_values = ['value2']
    statement = cql.prepare(f"INSERT INTO {table_name} (key, c1, c2) VALUES (?, ?, ?)")
    statement.consistency_level = ConsistencyLevel.ALL
    # cycle() lets the value lists be shorter than the key list; zip stops at the last key.
    params = [[key, c1, c2] for key, c1, c2 in
              zip(keys, itertools.cycle(c1_values), itertools.cycle(c2_values))]
    await asyncio.gather(*[cql.run_async(statement, p) for p in params])
    if flush:
        nodetool.flush(cql, table_name)
    return keys
async def read_verify_workload(cql: CassandraSession, table_name: str, keys: list[str]):
    """Check that every key in *keys* is readable from *table_name*.

    Issues one single-partition ``SELECT c1, c2`` per key, concurrently, and
    asserts that each returns exactly one row. All offending keys are listed
    in a single assertion message.
    """
    statement = cql.prepare(f"SELECT c1, c2 FROM {table_name} WHERE key = ?")
    results = await asyncio.gather(*[cql.run_async(statement, [key]) for key in keys])
    mismatches = []
    for key, result in zip(keys, results):
        # Materialize each result exactly once. The result object is not
        # guaranteed to be re-iterable, so listing it again for the error
        # message could misreport the row count.
        row_count = len(list(result))
        if row_count != 1:
            mismatches.append((key, row_count))
    assert not mismatches, (f"Expected exactly 1 row per key in {table_name}; "
                            f"(key, row count) mismatches: {mismatches}")
async def _smoke_test(manager: ManagerClient, key_provider: KeyProviderFactory,
                      ciphers: dict[str, list[int]], compression: str = None,
                      exception_handler: Callable[[Exception,str,str], None] = None,
                      options: dict = None,
                      num_servers: int = 1,
                      restart: Callable[[ManagerClient, list[ServerInfo], list[str]], Coroutine[None, None, None]] = None):
    """Create a cluster and encrypted tables, write data, restart, then verify the data.

    One table is created per (cipher, key length) combination that
    *key_provider* supports (via ``filter_ciphers``), so a single cluster
    restart covers every combination.

    Args:
        manager: test cluster manager.
        key_provider: key provider supplying node configuration and
            per-table encryption options.
        ciphers: cipher algorithm -> key lengths to try.
        compression: optional compressor name prefix (e.g. ``"LZ4"``).
        exception_handler: if given, called as ``(exc, cipher, length)`` when
            creating or populating a table raises, and that combination is
            skipped. Without it, the exception propagates.
        options: extra node configuration. It is merged under the provider's
            configuration parameters, so provider keys win.
        num_servers: cluster size, also used as the keyspace replication factor.
        restart: optional coroutine ``(manager, servers, table_names)`` used
            instead of the default rolling restart. A fresh CQL session is
            obtained after it returns.
    """
    # None instead of a ``{}`` default avoids a shared mutable default argument.
    cfg = (options or {}) | key_provider.configuration_parameters()
    servers: list[ServerInfo] = await manager.servers_add(servers_num = num_servers, config=cfg, auto_rack_dc='dc1')
    cql = manager.cql
    await wait_for_cql_and_get_hosts(cql, servers, time.time() + 60)
    async with await create_ks(manager, replication_factor = num_servers) as ks:
        # To reduce test time, create one cf for every alg/len combo we test.
        # This avoids rebooting the cluster for every check.
        async with contextlib.AsyncExitStack() as stack:
            cfs = []
            for cipher_algorithm, secret_key_strength in filter_ciphers(key_provider, ciphers):
                try:
                    additional_options = key_provider.additional_cf_options()
                    table_name = await stack.enter_async_context(
                        await create_encrypted_cf(manager, ks, cipher_algorithm=cipher_algorithm,
                                                  secret_key_strength=secret_key_strength,
                                                  compression=compression,
                                                  additional_options=additional_options
                                                  ))
                    keys = await prepare_write_workload(cql, table_name=table_name)
                    cfs.append((table_name, keys))
                except Exception as e:
                    if exception_handler:
                        exception_handler(e, cipher_algorithm, secret_key_strength)
                        continue
                    raise
            # Restart the cluster while the tables still exist (inside the exit stack).
            if restart:
                await restart(manager, servers, [table_name for table_name, _ in cfs])
                cql, _ = await manager.get_ready_cql(servers)
            else:
                await manager.rolling_restart(servers)
            for table_name, keys in cfs:
                await read_verify_workload(cql, table_name=table_name, keys=keys)
# Cipher algorithm -> key lengths exercised by the tests in this file.
# When no cipher_algorithm / secret_key_strength is given, the default is
# 'AES/CBC/PKCS5Padding' with length 128.
# The "" entry has no lengths, so filter_ciphers() yields no combinations for
# it. test_wrong_cipher_algorithm() also skips it explicitly.
supported_cipher_algorithms = {
    "": [],
    "AES/CBC/PKCS5Padding": [128, 192, 256], # 192 has problem (NOTE(review): details not recorded here)
    "AES/CBC": [128, 192, 256], # 192 has problem (NOTE(review): details not recorded here)
    "AES": [128, 192, 256], # 192 has problem (NOTE(review): details not recorded here)
    "AES/ECB/PKCS5Padding": [128, 192, 256],
    "AES/ECB": [128, 192, 256],
    # Legacy algorithms, not supported in OpenSSL 3.x:
    # "DES/CBC/PKCS5Padding": [56],
    # "DES/CBC": [56],
    # "DES": [56],
    # 'DESede/CBC/PKCS5Padding': [112, 168], # not supported by Scylla, supported by DSE
    # 'Blowfish/CBC/PKCS5Padding': [32, 448], # not supported by Scylla, supported by DSE
    # "RC2/CBC/PKCS5Padding": [80, 128], # [40, 80, 128] # 40 to 128
    # "RC2/CBC": [80, 128], # [40, 80, 128] # 40 to 128
    # "RC2": [80, 128], # [40, 80, 128] # 40 to 128
}
async def test_supported_cipher_algorithms(manager, key_provider):
    """Check that each provider can operate every cipher/length it claims to support.

    Failures are collected rather than raised, so every combination is tried
    and all failures are reported together.
    """
    errors = []

    def handler(e, cipher, length):
        logger.debug(str(e))
        errors.append(f"Cipher '{cipher}', length {length}, "
                      f"key provider '{key_provider}' failed. Error {e}")

    await _smoke_test(manager, key_provider=key_provider,
                      ciphers=supported_cipher_algorithms,
                      exception_handler=handler)
    assert not errors, errors
async def test_wrong_cipher_algorithm(manager, key_provider):
    """Check that invalid cipher algorithm strings are rejected with a recognizable error.

    Every broken name must fail with NoHostAvailable or ConfigurationException
    carrying one of the known rejection messages. Any other exception type is
    collected as an error. A matching type with an unexpected message fails
    the assertion inside the handler.
    """
    errors = []
    expected_errors = []

    # Derive invalid names by gluing "Abc" fragments onto each non-empty valid
    # algorithm, paired with that algorithm's first key length.
    broken_ciphers = {}
    for valid_name, lengths in supported_cipher_algorithms.items():
        if not valid_name:
            continue
        first_length = lengths[:1]
        for junk in ["Abc/", "/Abc", "Abc"]:
            broken_ciphers[valid_name + junk] = first_length
            broken_ciphers[junk + valid_name] = first_length

    def handler(e, cipher, length):
        if not isinstance(e, (NoHostAvailable, ConfigurationException)):
            errors.append(f"Unexpected exception: {e}. Encryption options: "
                          f"'key_provider': '{key_provider}', 'cipher_algorithm': "
                          f"'{cipher}', 'secret_key_strength': {length}")
            logger.debug(errors[-1])
            return
        error_str = str(e)
        logger.debug(error_str)
        assert (
            f"Invalid algorithm string: {cipher}" in error_str
            or ("Invalid algorithm" in error_str and cipher in error_str)
            or "Could not write key file" in error_str
            or ("[Server error] message=" in error_str and "abc" in error_str)
            or "non-supported padding option" in error_str
            or "routines::unsupported" in error_str
        ), error_str
        expected_errors.append(e)

    await _smoke_test(manager, key_provider=key_provider,
                      ciphers=broken_ciphers,
                      exception_handler=handler)
    # TODO: once https://github.com/scylladb/scylla-enterprise/issues/1973 is resolved,
    # also report broken ciphers that were accepted, e.g.:
    # assert not unexpected_success, "Negative tests succeeded unexpectedly: %s" % '\n'.join(unexpected_success)
    # (no `unexpected_success` list is collected yet).
    assert not errors, errors
    assert len(expected_errors) == len(broken_ciphers), expected_errors
@pytest.mark.parametrize(argnames="compression", argvalues=("LZ4", "Snappy", "Deflate"))
async def test_encryption_table_compression(manager, tmpdir, compression, scylla_binary):
    """Run the encryption smoke test on a table that also uses *compression*."""
    logger.debug("---- Test with compression: %s -----", compression)
    provider_context = make_key_provider_factory(KeyProvider.local, tmpdir, scylla_binary)
    async with provider_context as key_provider:
        single_cipher = {"AES/CBC/PKCS5Padding": [128]}
        await _smoke_test(manager, key_provider, ciphers=single_cipher, compression=compression)
async def test_reboot(manager, key_provider):
    """Stop and start each node of a (normally 3-node) cluster, then verify encrypted data."""
    async def restart(manager: ManagerClient, servers: list[ServerInfo], table_names: list[str]):
        # pylint: disable=unused-argument
        for server in servers:
            await manager.server_stop(server.server_id)
            await manager.server_start(server.server_id)

    # The replicated provider cannot handle a hard reboot of the cluster safely:
    # there is no guarantee that keys have propagated so that restarted nodes
    # can reach them. A single node still lets the test run, if less thoroughly.
    num_servers = 1 if key_provider.key_provider == KeyProvider.replicated else 3
    await _smoke_test(manager, key_provider=key_provider,
                      ciphers={"AES/CBC/PKCS5Padding": [128]},
                      options={"commitlog_sync": "batch"},
                      num_servers=num_servers,
                      restart=restart)
def get_sstables(node_workdir, ks:str, table:str, sst_type = None):
    """Return paths of ``*.db`` sstable files for *ks*.*table* under *node_workdir*.

    Searches ``<node_workdir>/data/<ks>*/<table>-*/``. When *sst_type* is
    given (e.g. ``"Scylla"``), only ``*-<sst_type>.db`` components match.
    """
    component_suffix = f"-{sst_type}.db" if sst_type else ".db"
    pattern = os.path.join(node_workdir, "data", f"{ks}*", f"{table}-*", "*" + component_suffix)
    return glob.glob(pattern)
async def get_sstable_metadata(manager: ManagerClient, server: ServerInfo, keyspace: str, column_family: str):
    """Dump the Scylla metadata component of a table's sstables on one node.

    Runs ``scylla sstable dump-scylla-metadata`` with the node's own binary and
    ``conf/scylla.yaml`` over every ``*-Scylla.db`` file of
    *keyspace*.*column_family*, and returns the parsed JSON output.
    """
    scylla_path = await manager.server_get_exe(server.server_id)
    node_workdir = await manager.server_get_workdir(server.server_id)
    metadata_components = get_sstables(node_workdir, keyspace, column_family, 'Scylla')
    yaml_file = os.path.join(node_workdir, "conf", "scylla.yaml")
    command = [scylla_path, "sstable", "dump-scylla-metadata",
               "--scylla-yaml-file", yaml_file,
               "--sstables", *metadata_components]
    raw_output = subprocess.check_output(command, stderr=subprocess.PIPE)
    return json.loads(raw_output.decode('utf-8', 'ignore'))
async def validate_sstables_encryption(manager: ManagerClient, server: ServerInfo, table_name: str, encrypted:bool, expected_data=None):
    """Assert whether the sstables of *table_name* on *server* are encrypted.

    Args:
        manager: test cluster manager.
        server: node whose sstables are inspected.
        table_name: fully qualified ``keyspace.table`` name.
        encrypted: if true, every sstable's Scylla metadata must carry
            ``scylla_encryption_options``; if false, none may.
        expected_data: optional list of rows (each a list of column values).
            ``scylla sstable query`` over the table's sstables must return
            exactly this list, in order.
    """
    keyspace, column_family = table_name.split(".")
    scylla_path = await manager.server_get_exe(server.server_id)
    node_workdir = await manager.server_get_workdir(server.server_id)
    scylla_yaml = os.path.join(node_workdir, "conf", "scylla.yaml")
    # Keep autocompaction off for the whole check, including the data query,
    # so the sstables inspected and globbed below are not compacted away in between.
    with nodetool.no_autocompaction_context(manager.cql, table_name):
        scylla_metadata = await get_sstable_metadata(manager, server, keyspace, column_family)
        logger.debug("validate_sstables_encryption(): scylla_metadata=%s", scylla_metadata)
        encrypt_opts = ["scylla_encryption_options" in metadata.get("extension_attributes", {})
                        for metadata in scylla_metadata['sstables'].values()]
        assert encrypt_opts, f"no sstables found for {table_name} on server {server.server_id}"
        if encrypted:
            assert all(encrypt_opts), encrypt_opts
        else:
            assert not any(encrypt_opts), encrypt_opts
        if expected_data is not None:
            sstables = get_sstables(node_workdir, keyspace, column_family)
            res = subprocess.check_output([scylla_path, "sstable", "query",
                                           "--scylla-yaml-file", scylla_yaml,
                                           "--output-format", "json", "--sstables"] + sstables,
                                          stderr=subprocess.PIPE)
            scylla_data = json.loads(res.decode('utf-8', 'ignore'))
            actual_data = [list(r.values()) for r in scylla_data]
            assert actual_data == expected_data
async def test_alter(manager, key_provider):
    """Alter an encrypted table to unencrypted and back, checking sstables and data.

    After the smoke-test write, the restart hook:
    1. snapshots the table contents;
    2. checks that the sstables are encrypted (skipped for the replicated
       provider, whose tables the sstable tool cannot read);
    3. disables encryption and upgrades the sstables;
    4. checks that they are now unencrypted, with the same data;
    5. re-enables encryption, upgrades again and rolling-restarts.
    """
    async def restart(manager: ManagerClient, servers: list[ServerInfo], table_names: list[str]):
        cql = manager.cql
        table_name = table_names[0]
        keyspace = table_name.split(".")[0]
        expected_data = [list(row._asdict().values())
                         for row in cql.execute(f"SELECT * FROM {table_name}")]
        logger.info("expected_data=%s", expected_data)
        # We cannot use tools like scylla sstable to read encrypted tables
        # with the replicated provider.
        if key_provider.key_provider != KeyProvider.replicated:
            await validate_sstables_encryption(manager, servers[0],
                                               table_name, True,
                                               expected_data=expected_data)
        # Disable encryption.
        cql.execute(f"ALTER TABLE {table_name} with "
                    "scylla_encryption_options={'key_provider':'none'}")
        table_desc = cql.execute(f"DESC {table_name}").one().create_statement
        assert "key_provider" not in table_desc, f"key_provider isn't disabled, schema:\n {table_desc}"
        await manager.api.keyspace_upgrade_sstables(servers[0].ip_addr, keyspace)
        await validate_sstables_encryption(manager, servers[0],
                                           table_name, False,
                                           expected_data=expected_data)
        # The partition key is the first column of SELECT *.
        await read_verify_workload(cql, table_name=table_name, keys=[row[0] for row in expected_data])
        # Enable encryption again.
        options = key_provider.additional_cf_options()
        cql.execute(f"ALTER TABLE {table_name} with scylla_encryption_options={options}")
        await manager.api.keyspace_upgrade_sstables(servers[0].ip_addr, keyspace)
        table_desc = cql.execute(f"DESC {table_name}").one().create_statement
        assert options["key_provider"] in table_desc, f"key_provider isn't set, schema:\n {table_desc}"
        await manager.rolling_restart(servers)

    await _smoke_test(manager, key_provider=key_provider,
                      ciphers={"AES/CBC/PKCS5Padding": [128]},
                      restart=restart)
async def test_per_table_master_key(manager: ManagerClient, tmpdir):
    """Give each table its own KMS master key and check that its sstables reference that key."""
    class MultiAliasKMSProvider (KMSKeyProviderFactory):
        """KMS provider that creates a new master key and alias for every table it configures."""
        def __init__(self, tmpdir):
            super().__init__(tmpdir)
            self.key_count: int = 0
            self.key_ids: list = []
            self.aliases: list = []

        def additional_cf_options(self):
            alias_name = f"alias/Scylla-test-{self.key_count}"
            new_key_id = self.create_master_key(alias_name=alias_name)
            self.key_ids.append(new_key_id)
            self.aliases.append(alias_name)
            self.key_count += 1
            return super().additional_cf_options() | {"master_key": alias_name}

    async with MultiAliasKMSProvider(tmpdir) as kp:
        async def restart(manager: ManagerClient, servers: list[ServerInfo],
                          table_names: list[str]):
            await manager.rolling_restart(servers)
            node = servers[0]
            # Tables were created in the same order as their master keys.
            for index, table_name in enumerate(table_names):
                keyspace, column_family = table_name.split(".")
                await validate_sstables_encryption(manager, node, table_name, True)
                with nodetool.no_autocompaction_context(manager.cql, table_name):
                    scylla_metadata = await get_sstable_metadata(manager, node,
                                                                 keyspace, column_family)
                sstable_key_ids = [meta.get("extension_attributes", {})["scylla_key_id"]
                                   for meta in scylla_metadata['sstables'].values()]
                expected_key_id = kp.key_ids[index]
                # AWS KMS key ids are encoded as encrypting key + encrypted data,
                # so the ID returned when the key was created should appear in
                # each sstable's key identifier.
                matches = [expected_key_id in sstable_key_id for sstable_key_id in sstable_key_ids]
                assert all(matches)

        await _smoke_test(manager, kp,
                          ciphers={"AES/CBC/PKCS5Padding": [128, 256]},
                          restart=restart)
async def test_non_existant_table_master_key(manager: ManagerClient, tmpdir):
    """Check that the smoke test fails when tables reference a non-existent master key."""
    class NoSuchKeyKMSProvider (KMSKeyProviderFactory):
        """KMS provider whose tables name a master key alias this test never creates."""
        def additional_cf_options(self):
            missing_key = {"master_key": "alias/NoSuchKey"}
            return super().additional_cf_options() | missing_key

    ciphers = {"AES/CBC/PKCS5Padding": [128]}
    async with NoSuchKeyKMSProvider(tmpdir) as kp:
        with pytest.raises(Exception):
            await _smoke_test(manager, kp, ciphers=ciphers)
async def test_system_auth_encryption(manager: ManagerClient, tmpdir):
    """Check that system_info_encryption hides auth and schema data in commitlogs.

    Before encryption is enabled, a new user's name, its salted hash and a
    table comment must be found by a plain byte search of the commitlog
    files. After ``system_info_encryption`` is enabled and the node is
    restarted, newly created ones must not be found. The plaintext password
    must never be found.
    """
    cfg = {"authenticator": "org.apache.cassandra.auth.PasswordAuthenticator",
           "authorizer": "org.apache.cassandra.auth.CassandraAuthorizer",
           "commitlog_sync": "batch" }
    servers: list[ServerInfo] = await manager.servers_add(servers_num = 1, config=cfg,
        driver_connect_opts={'auth_provider': PlainTextAuthProvider(username='cassandra', password='cassandra')})
    cql = manager.cql
    await wait_for_cql_and_get_hosts(cql, servers, time.time() + 60)

    async def grep_database_files(pattern: str, path: str, files: str, expect:bool):
        """Assert whether *pattern* occurs in regular files matching *files* under <workdir>/*path*."""
        pattern_found_counter = 0
        pbytes = pattern.encode("utf-8")
        for server in servers:
            node_workdir = await manager.server_get_workdir(server.server_id)
            dirname = os.path.join(node_workdir, path)
            file_paths = glob.glob(os.path.join(dirname, files), recursive=True)
            file_paths = [f for f in file_paths if os.path.isfile(f) and not os.path.islink(f)]
            for file_path in file_paths:
                try:
                    with open(file_path, 'rb') as f:
                        data = f.read()
                    if pbytes in data:
                        pattern_found_counter += 1
                        logger.debug("Pattern '%s' found in %s", pattern, file_path)
                except FileNotFoundError:
                    pass # assume just compacted away
        if expect:
            assert pattern_found_counter > 0, f"'{pattern}' not found in {path}/{files}"
        else:
            assert pattern_found_counter == 0, \
                f"'{pattern}' unexpectedly found in {pattern_found_counter} file(s) under {path}/{files}"

    async def verify_system_info(expect: bool):
        """Create fresh auth/schema data and check its visibility in commitlogs against *expect*."""
        user = f"user_{str(uuid.uuid4())}".replace('-','_')
        pwd = f"pwd_{str(uuid.uuid4())}"
        cql.execute(f"CREATE USER {user} WITH PASSWORD '{pwd}' NOSUPERUSER")
        assert_one(cql, f"LIST ROLES of {user}", [user, False, True, {}])
        logger.debug("Verify PART 1: check commitlogs -------------")
        await grep_database_files(pwd, "commitlog", "**/*.log", False)
        await grep_database_files(user, "commitlog", "**/*.log", expect)
        salted_hash = None
        system_auth = None
        # Look up the salted hash of the user just created. The roles table
        # may live in any of these keyspaces (presumably depending on the auth
        # version), so probe them in order and use the first that answers.
        for ks in ['system', 'system_auth_v2', 'system_auth']:
            try:
                salted_hash = cql.execute(f"SELECT salted_hash FROM {ks}.roles WHERE role = '{user}'").one().salted_hash
                system_auth = ks
                break
            except Exception:
                # Missing table in this keyspace, or no row (``one()`` returned None).
                pass
        assert salted_hash is not None
        assert system_auth is not None
        await grep_database_files(salted_hash, "commitlog", "**/*.log", expect)
        rand_comment = f"comment_{str(uuid.uuid4())}"
        async with await create_ks(manager) as ks:
            async with new_test_table(manager, ks, "key text PRIMARY KEY, c1 text, c2 text") as table:
                cql.execute(f"ALTER TABLE {table} WITH comment = '{rand_comment}'")
                await grep_database_files(rand_comment, "commitlog/schema", "**/*.log", expect)
        # Note: the original test also grepped sstables. That no longer works,
        # since all system tables are compressed, so a binary grep finds nothing.
        # Using scylla sstable dump-data and grepping the JSON would be
        # somewhat pointless: if the tool handles encryption it just decrypts
        # the data, so nothing is really verified. One could check that the
        # expected system tables are in fact encrypted, but that is more a
        # promise than a guarantee. Also, the only tables encrypted are paxos
        # and batchlog, which makes the check pointless.

    await verify_system_info(True) # not encrypted
    cfg = {"system_info_encryption": {
               "enabled": True,
               "key_provider": "LocalFileSystemKeyProviderFactory"},
           "system_key_directory": os.path.join(tmpdir, "resources/system_keys")
           }
    for server in servers:
        await manager.server_update_config(server.server_id, config_options=cfg)
        await manager.server_restart(server.server_id)
    await manager.rolling_restart(servers)
    await verify_system_info(False) # should not see stuff now
async def test_system_encryption_reboot(manager: ManagerClient, tmpdir):
    """Stop and start a node running with system_info_encryption, then verify its table data."""
    async def restart(manager: ManagerClient, servers: list[ServerInfo], table_names: list[str]):
        # pylint: disable=unused-argument
        for server in servers:
            await manager.server_stop(server.server_id)
            await manager.server_start(server.server_id)

    system_info_encryption = {
        "enabled": True,
        "key_provider": "LocalFileSystemKeyProviderFactory",
    }
    options = {"commitlog_sync": "batch", "system_info_encryption": system_info_encryption}
    async with LocalFileSystemKeyProviderFactory(tmpdir) as kp:
        await _smoke_test(manager, key_provider=kp,
                          ciphers={"AES/CBC/PKCS5Padding": [128]},
                          options=options,
                          restart=restart)