Files
scylladb/test/pylib/encryption_provider.py
Andrei Chekun 99234f0a83 test.py: replace SCYLLA env var with build_mode fixture
Replace direct usage of SCYLLA environment variable with the build_mode
pytest fixture and path_to helper function. This makes tests more
flexible and consistent with the test framework. Also this allows to use
tests with xdist, where environment variable can be left in the master
process and will not be set in the workers
Add using the fixture to get the scylla binary from the suite, this will
align with getting relocatable Scylla exe.
2026-02-24 09:48:38 +01:00

310 lines
12 KiB
Python

#
# Copyright (C) 2025-present ScyllaDB
#
# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
#
import os
import asyncio
import logging
import threading
from enum import Enum
from functools import cached_property
from test import TEST_DIR
from test.pylib.dockerized_service import DockerizedServer
from test.pylib.kmip_wrapper import KMIPServerWrapper
from test.pylib.azure_vault_server_mock import MockAzureVaultServer
import boto3
class KeyProvider(Enum):
"""Enumeration of key providers in scylla"""
# pylint: disable=invalid-name
local = "LocalFileSystemKeyProviderFactory"
replicated = "ReplicatedKeyProviderFactory"
kmip = "KmipKeyProviderFactory"
kms = "KmsKeyProviderFactory"
azure = "AzureKeyProviderFactory"
class KeyProviderFactory:
"""Base class for provider factories"""
def __init__(self, key_provider : KeyProvider, tmpdir):
self.key_provider = key_provider
self.system_key_location = os.path.join(tmpdir, "resources/system_keys")
async def __aenter__(self):
return self
async def __aexit__(self, exception_type, exception_value, exception_traceback):
pass
def supported_cipher(self, cipher_algorithm, secret_key_strength):
# pylint: disable=unused-argument
"""Is cipher type supported by provider"""
return True
def require_restart(self):
"""Must servers be restarted on keyspace parameter changes?"""
return False
def configuration_parameters(self) -> dict[str, str]:
"""scylla.conf entries for provider"""
return {"system_key_directory": self.system_key_location}
def additional_cf_options(self) -> dict[str, str]:
# pylint: disable=unused-argument
"""keyspace options"""
if self.key_provider:
return {"key_provider": self.key_provider.value}
return {}
class LocalFileSystemKeyProviderFactory(KeyProviderFactory):
"""LocalFileSystemKeyProviderFactory proxy"""
def __init__(self, tmpdir):
super(LocalFileSystemKeyProviderFactory, self).__init__(KeyProvider.local, tmpdir)
self.secret_file = os.path.join(tmpdir, "test/node1/conf/data_encryption_keys")
def additional_cf_options(self) -> dict[str, str]:
"""scylla.conf entries for provider"""
return super().additional_cf_options() | {"secret_key_file": self.secret_file}
class ReplicatedKeyProviderFactory(KeyProviderFactory):
"""ReplicatedKeyProviderFactory proxy"""
def __init__(self, tmpdir, scylla_exe):
super(ReplicatedKeyProviderFactory, self).__init__(KeyProvider.replicated, tmpdir)
self.system_key_file_name = "system_key"
self.scylla_exe = scylla_exe
async def __aenter__(self):
await super().__aenter__()
args = ["local-file-key-generator", "generate", os.path.join(self.system_key_location, self.system_key_file_name)]
proc = await asyncio.create_subprocess_exec(self.scylla_exe, *args, stderr=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE)
_, stderr = await proc.communicate()
if proc.returncode != 0:
raise RuntimeError(f'Could not generate system key: {stderr.decode()}')
return self
def additional_cf_options(self):
return super().additional_cf_options() | {"system_key": self.system_key_file_name}
class KmipKeyProviderFactory(KeyProviderFactory):
"""KmipKeyProviderFactory proxy"""
def __init__(self, tmpdir):
super(KmipKeyProviderFactory, self).__init__(KeyProvider.kmip, tmpdir)
self.tmpdir = tmpdir
self.kmip_server_wrapper = None
self.kmip_host = "kmip_test"
self.kmip_port = 0
self.kmip_server = None
self.kmip_thread = None
resourcedir = os.path.join(TEST_DIR, "pylib/resources")
# TODO: we really need certs generated on test run
self.certs = { "certfile": os.path.join(resourcedir, "scylla.crt"),
"keyfile": os.path.join(resourcedir, "scylla.key"),
"truststore": os.path.join(resourcedir, "scylla.crt")
}
def configuration_parameters(self) -> dict[str, str]:
"""scylla.conf entries for provider"""
options = {
"hosts": "127.0.0.1:" + str(self.kmip_port),
"certificate": self.certs["certfile"],
"keyfile": self.certs["keyfile"],
"truststore": self.certs["truststore"],
"priority_string": "SECURE128:+RSA:-VERS-TLS1.0:-ECDHE-ECDSA",
}
return super().configuration_parameters() | { "kmip_hosts": { self.kmip_host: options } }
def kmip_serve(self):
"""KMIP server loop"""
self.kmip_server_wrapper.serve()
async def __aenter__(self):
assert os.path.exists(self.certs["certfile"])
assert os.path.exists(self.certs["keyfile"])
assert os.path.exists(self.certs["truststore"])
kmiplog = logging.getLogger("kmip.server")
kmiplog.handlers.clear() # make pykmip shut up a bit. log will written to log file (setup in init below)
self.kmip_server_wrapper = KMIPServerWrapper(
hostname="127.0.0.1",
config_path=None,
certificate_path=self.certs["certfile"],
policy_path=self.tmpdir.strpath,
key_path=self.certs["keyfile"],
ca_path=self.certs["truststore"],
auth_suite="TLS1.2",
database_path=':memory:',
logging_level='DEBUG',
log_path=os.path.join(self.tmpdir, "pykmip.log"),
enable_tls_client_auth=False,
)
self.kmip_server_wrapper.start()
self.kmip_port = self.kmip_server_wrapper.port
assert len(kmiplog.handlers) == 1
self.kmip_thread = threading.Thread(name="kmip server", target=self.kmip_serve, daemon=True)
self.kmip_thread.start()
return self
async def __aexit__(self, exception_type, exception_value, exception_traceback):
# pylint: disable=bare-except
if self.kmip_server_wrapper is not None:
self.kmip_server_wrapper.stop()
self.kmip_server_wrapper = None
if self.kmip_thread is not None:
self.kmip_thread.join()
self.kmip_thread = None
def additional_cf_options(self):
return super().additional_cf_options() | {"kmip_host": self.kmip_host}
def require_restart(self):
return True
def supported_cipher(self, cipher_algorithm, secret_key_strength):
# Our KMIP server is not configured to support this configuration.
# Test fails with error: Invalid key data length 80 for RC2/CBC and kmip.
# Decided (Roy) don't test it
return not ("RC2" in cipher_algorithm and secret_key_strength == 80)
class KMSKeyProviderFactory(KeyProviderFactory):
"""KMSKeyProviderFactory proxy"""
def __init__(self, tmpdir):
super(KMSKeyProviderFactory, self).__init__(KeyProvider.kms, tmpdir)
self.tmpdir = tmpdir
self.master_key = "alias/Scylla-test"
self.kms_host = "kms_test"
self.endpoint_url = None
self.server = None
self.region = None
def _docker_args(self, host, port):
# pylint: disable=unused-argument
return ["-e", f'PORT={port}']
async def __aenter__(self):
master_key = os.getenv('KMS_KEY_ALIAS')
aws_region = os.getenv('KMS_AWS_REGION')
if master_key is None:
self.server = DockerizedServer("docker.io/nsmithuk/local-kms:3", self.tmpdir,
logfilenamebase="local-kms",
docker_args=self._docker_args,
success_string="Local KMS started on",
failure_string="address already in use"
)
await self.server.start()
self.endpoint_url = f'http://{self.server.host}:{self.server.port}'
self.create_master_key()
else:
if aws_region is None:
aws_region = 'us-east-1'
self.region = aws_region
self.master_key = master_key
return self
async def __aexit__(self, exception_type, exception_value, exception_traceback):
if self.server is not None:
await self.server.stop()
self.server = None
@cached_property
def kms_client(self):
"""A boto client"""
return boto3.client("kms", endpoint_url=self.endpoint_url, region_name="None")
def create_master_key(self, alias_name:str = None):
# pylint: disable=broad-exception-caught
"""(re-)create key"""
# create master key
response = self.kms_client.create_key(Description="dtest", Tags=[{"TagKey": "Name", "TagValue": "dtest"}])
key_id = response["KeyMetadata"]["KeyId"]
if alias_name is None:
alias_name = self.master_key
try:
self.kms_client.delete_alias(AliasName=alias_name)
except Exception:
pass
self.kms_client.create_alias(AliasName=alias_name, TargetKeyId=key_id)
return key_id
def configuration_parameters(self) -> dict[str, str]:
"""scylla.conf entries for provider"""
options = {
"master_key": self.master_key
}
if self.endpoint_url is not None:
options['endpoint'] = self.endpoint_url
if self.region is not None:
options['aws_regions'] = self.region
return super().configuration_parameters() | { "kms_hosts": { self.kms_host: options } }
def additional_cf_options(self):
return super().additional_cf_options() | {"kms_host": self.kms_host}
def supported_cipher(self, cipher_algorithm, secret_key_strength):
return secret_key_strength >= 128
def require_restart(self):
return True
class AzureKeyProviderFactory(KeyProviderFactory):
"""AzureKeyProviderFactory proxy"""
def __init__(self, tmpdir):
super(AzureKeyProviderFactory, self).__init__(KeyProvider.azure, tmpdir)
self.tmpdir = tmpdir
self.azure_host = "azure_test"
self.azure_server = None
self.host = None;
self.port = 0
def configuration_parameters(self) -> dict[str, str]:
"""scylla.conf entries for provider"""
options = {
"master_key": f"http://{self.host}:{self.port}/mock-key",
"azure_tenant_id": "00000000-1111-2222-3333-444444444444",
"azure_client_id": "mock-client-id",
"azure_client_secret": "mock-client-secret",
"azure_authority_host": f"http://{self.host}:{self.port}"
}
return super().configuration_parameters() | { "azure_hosts": { self.azure_host: options } }
async def __aenter__(self):
self.azure_server = MockAzureVaultServer("127.0.0.1", 0, logging.getLogger('azure-vault-server'))
await self.azure_server.start()
self.host = self.azure_server.server_address[0]
self.port = self.azure_server.server_address[1]
return self
async def __aexit__(self, exception_type, exception_value, exception_traceback):
if self.azure_server is not None:
await self.azure_server.stop()
self.azure_server = None
def additional_cf_options(self):
return super().additional_cf_options() | {"azure_host": self.azure_host}
def require_restart(self):
return True
def make_key_provider_factory(provider: KeyProvider, tmpdir, scylla_exe):
"""Create key provider factory for enum"""
res = None
if provider == KeyProvider.local:
res = LocalFileSystemKeyProviderFactory(tmpdir)
elif provider == KeyProvider.replicated:
res = ReplicatedKeyProviderFactory(tmpdir, scylla_exe)
elif provider == KeyProvider.kmip:
res = KmipKeyProviderFactory(tmpdir)
elif provider == KeyProvider.kms:
res = KMSKeyProviderFactory(tmpdir)
elif provider == KeyProvider.azure:
res = AzureKeyProviderFactory(tmpdir)
else:
raise RuntimeError(f'Unknown key_provider: {provider}')
return res