Files
scylladb/test/pylib/encryption_provider.py
Calle Wilund 3e8a9a0beb pytest: use ephemeral port publish for docker mock servers
Changes dockerized_service to use ephemeral port publish, and
query the published port from podman/docker.
Modifies client code to use slightly changed usage syntax.
2026-03-11 12:32:01 +01:00

306 lines
12 KiB
Python

#
# Copyright (C) 2025-present ScyllaDB
#
# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
#
import os
import asyncio
import logging
import threading
from enum import Enum
from functools import cached_property
from test import TEST_DIR
from test.pylib.dockerized_service import DockerizedServer
from test.pylib.kmip_wrapper import KMIPServerWrapper
from test.pylib.azure_vault_server_mock import MockAzureVaultServer
import boto3
class KeyProvider(Enum):
    """Key provider kinds known to scylla.

    The value of each member is the provider factory name string, i.e. the
    text used for the ``key_provider`` keyspace option.
    """
    # pylint: disable=invalid-name

    # Keys kept in a file on the local file system.
    local = "LocalFileSystemKeyProviderFactory"
    # Replicated key provider.
    replicated = "ReplicatedKeyProviderFactory"
    # Keys served by a KMIP server.
    kmip = "KmipKeyProviderFactory"
    # Keys served by (AWS) KMS.
    kms = "KmsKeyProviderFactory"
    # Keys served by an Azure vault.
    azure = "AzureKeyProviderFactory"
class KeyProviderFactory:
    """Common base for the per-provider factory helpers.

    Stores the provider kind and the system key directory (located under
    ``tmpdir``) and supplies default answers that concrete providers
    override. Instances are async context managers; entering and leaving
    the base class does nothing.
    """

    def __init__(self, key_provider : KeyProvider, tmpdir):
        self.key_provider = key_provider
        self.system_key_location = os.path.join(tmpdir, "resources/system_keys")

    async def __aenter__(self):
        # Nothing to set up for the base class.
        return self

    async def __aexit__(self, exception_type, exception_value, exception_traceback):
        # Nothing to tear down for the base class.
        return None

    def supported_cipher(self, cipher_algorithm, secret_key_strength):
        # pylint: disable=unused-argument
        """Tell whether the provider can handle the given cipher/key strength.

        The base implementation accepts everything.
        """
        return True

    def require_restart(self):
        """Tell whether servers must be restarted on keyspace parameter changes.

        The base implementation answers ``False``.
        """
        return False

    def configuration_parameters(self) -> dict[str, str]:
        """Return the scylla.conf entries for the provider."""
        return {"system_key_directory": self.system_key_location}

    def additional_cf_options(self) -> dict[str, str]:
        """Return the keyspace options for the provider.

        Yields a ``key_provider`` entry holding the provider's value, or an
        empty dict when no (truthy) provider is set.
        """
        provider = self.key_provider
        return {"key_provider": provider.value} if provider else {}
class LocalFileSystemKeyProviderFactory(KeyProviderFactory):
    """Factory helper for the local file system key provider."""

    def __init__(self, tmpdir):
        super().__init__(KeyProvider.local, tmpdir)
        # Secret key file path, located under tmpdir.
        self.secret_file = os.path.join(tmpdir, "test/node1/conf/data_encryption_keys")

    def additional_cf_options(self) -> dict[str, str]:
        """Return the keyspace options: the base options plus ``secret_key_file``."""
        base_options = super().additional_cf_options()
        return base_options | {"secret_key_file": self.secret_file}
class ReplicatedKeyProviderFactory(KeyProviderFactory):
    """Factory helper for the replicated key provider.

    Entering the context runs the scylla executable to generate the system
    key file inside the system key directory.
    """

    def __init__(self, tmpdir, scylla_exe):
        super().__init__(KeyProvider.replicated, tmpdir)
        self.system_key_file_name = "system_key"
        self.scylla_exe = scylla_exe

    async def __aenter__(self):
        """Generate the system key; raise RuntimeError if the tool exits non-zero."""
        await super().__aenter__()
        key_path = os.path.join(self.system_key_location, self.system_key_file_name)
        generator = await asyncio.create_subprocess_exec(
            self.scylla_exe, "local-file-key-generator", "generate", key_path,
            stderr=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
        )
        # Only stderr is of interest; it is reported when generation fails.
        _, err_output = await generator.communicate()
        if generator.returncode == 0:
            return self
        raise RuntimeError(f'Could not generate system key: {err_output.decode()}')

    def additional_cf_options(self):
        """Return the keyspace options: the base options plus ``system_key``."""
        base_options = super().additional_cf_options()
        return base_options | {"system_key": self.system_key_file_name}
class KmipKeyProviderFactory(KeyProviderFactory):
    """Factory helper for the KMIP key provider.

    Entering the context starts a ``KMIPServerWrapper`` on 127.0.0.1 and
    serves it from a daemon thread; leaving the context stops the server
    and joins that thread.
    """

    def __init__(self, tmpdir):
        super().__init__(KeyProvider.kmip, tmpdir)
        self.tmpdir = tmpdir
        self.kmip_server_wrapper = None
        self.kmip_host = "kmip_test"
        # Replaced by the wrapper's actual port once the server is started.
        self.kmip_port = 0
        self.kmip_server = None
        self.kmip_thread = None
        resource_dir = os.path.join(TEST_DIR, "pylib/resources")
        # TODO: we really need certs generated on test run
        self.certs = {
            "certfile": os.path.join(resource_dir, "scylla.crt"),
            "keyfile": os.path.join(resource_dir, "scylla.key"),
            # The certificate doubles as the trust store.
            "truststore": os.path.join(resource_dir, "scylla.crt"),
        }

    def configuration_parameters(self) -> dict[str, str]:
        """Return the scylla.conf entries: base entries plus a ``kmip_hosts`` section."""
        host_options = {
            "hosts": "127.0.0.1:" + str(self.kmip_port),
            "certificate": self.certs["certfile"],
            "keyfile": self.certs["keyfile"],
            "truststore": self.certs["truststore"],
            "priority_string": "SECURE128:+RSA:-VERS-TLS1.0:-ECDHE-ECDSA",
        }
        base_parameters = super().configuration_parameters()
        return base_parameters | {"kmip_hosts": {self.kmip_host: host_options}}

    def kmip_serve(self):
        """Run the KMIP server loop (target of the server thread)."""
        self.kmip_server_wrapper.serve()

    async def __aenter__(self):
        """Start the KMIP server and its serving thread."""
        # All certificate material must be present before starting.
        for cert_kind in ("certfile", "keyfile", "truststore"):
            assert os.path.exists(self.certs[cert_kind])

        kmip_logger = logging.getLogger("kmip.server")
        # Make pykmip shut up a bit. The log will be written to the log file
        # (set up by the wrapper below).
        kmip_logger.handlers.clear()

        self.kmip_server_wrapper = KMIPServerWrapper(
            hostname="127.0.0.1",
            config_path=None,
            certificate_path=self.certs["certfile"],
            policy_path=self.tmpdir.strpath,
            key_path=self.certs["keyfile"],
            ca_path=self.certs["truststore"],
            auth_suite="TLS1.2",
            database_path=':memory:',
            logging_level='DEBUG',
            log_path=os.path.join(self.tmpdir, "pykmip.log"),
            enable_tls_client_auth=False,
        )
        self.kmip_server_wrapper.start()
        self.kmip_port = self.kmip_server_wrapper.port
        # Exactly one handler is expected after start -- presumably the log
        # file handler installed by the wrapper.
        assert len(kmip_logger.handlers) == 1

        serve_thread = threading.Thread(name="kmip server", target=self.kmip_serve, daemon=True)
        self.kmip_thread = serve_thread
        serve_thread.start()
        return self

    async def __aexit__(self, exception_type, exception_value, exception_traceback):
        """Stop the KMIP server, then wait for the serving thread to finish."""
        if self.kmip_server_wrapper is not None:
            self.kmip_server_wrapper.stop()
            self.kmip_server_wrapper = None
        if self.kmip_thread is not None:
            self.kmip_thread.join()
            self.kmip_thread = None

    def additional_cf_options(self):
        """Return the keyspace options: the base options plus ``kmip_host``."""
        base_options = super().additional_cf_options()
        return base_options | {"kmip_host": self.kmip_host}

    def require_restart(self):
        """Servers must be restarted on keyspace parameter changes."""
        return True

    def supported_cipher(self, cipher_algorithm, secret_key_strength):
        """Accept every combination except RC2 with an 80 bit key."""
        # Our KMIP server is not configured to support this configuration.
        # Test fails with error: Invalid key data length 80 for RC2/CBC and kmip.
        # Decided (Roy) don't test it
        return "RC2" not in cipher_algorithm or secret_key_strength != 80
class KMSKeyProviderFactory(KeyProviderFactory):
    """KMSKeyProviderFactory proxy.

    When the ``KMS_KEY_ALIAS`` environment variable is set, its value is used
    as the master key and no server is started; the region is taken from
    ``KMS_AWS_REGION`` (default ``us-east-1``). Otherwise a dockerized
    local-kms server is started on entering the context, a master key is
    created on it, and the server is stopped again on leaving the context.
    """

    def __init__(self, tmpdir):
        super().__init__(KeyProvider.kms, tmpdir)
        self.tmpdir = tmpdir
        self.master_key = "alias/Scylla-test"
        self.kms_host = "kms_test"
        # Set only when a local mock server is used.
        self.endpoint_url = None
        self.server = None
        # Set only when the master key comes from the environment.
        self.region = None

    async def __aenter__(self):
        """Start the local KMS server, or pick up key/region from the environment.

        Raises whatever server start-up or master key creation raises. If a
        step fails after the server was started, the server is stopped before
        the error propagates.
        """
        master_key = os.getenv('KMS_KEY_ALIAS')
        aws_region = os.getenv('KMS_AWS_REGION')
        if master_key is None:
            self.server = DockerizedServer("docker.io/nsmithuk/local-kms:3", self.tmpdir,
                                           logfilenamebase="local-kms",
                                           success_string="Local KMS started on",
                                           failure_string="address already in use",
                                           port=8080
                                           )
            await self.server.start()
            try:
                self.endpoint_url = f'http://{self.server.host}:{self.server.port}'
                self.create_master_key()
            except BaseException:  # pylint: disable=broad-exception-caught
                # __aexit__ is not invoked when __aenter__ raises, so the
                # already started server must be stopped here or it leaks.
                await self._stop_server()
                raise
        else:
            if aws_region is None:
                aws_region = 'us-east-1'
            self.region = aws_region
            self.master_key = master_key
        return self

    async def __aexit__(self, exception_type, exception_value, exception_traceback):
        """Stop the local KMS server, if one was started."""
        await self._stop_server()

    async def _stop_server(self):
        """Drop the cached KMS client and stop the local server, if any."""
        # cached_property keeps its value in the instance __dict__. The client
        # is bound to the endpoint it was created with, so forget it to let
        # the next access build a fresh one.
        self.__dict__.pop("kms_client", None)
        if self.server is not None:
            await self.server.stop()
            self.server = None

    @cached_property
    def kms_client(self):
        """A boto client for ``endpoint_url``, created on first access."""
        # NOTE(review): region_name is the literal string "None" -- presumably
        # a placeholder that is sufficient for the local mock; confirm.
        return boto3.client("kms", endpoint_url=self.endpoint_url, region_name="None")

    def create_master_key(self, alias_name:str = None):
        # pylint: disable=broad-exception-caught
        """(Re-)create the master key and point an alias at it.

        :param alias_name: alias to (re)assign; defaults to ``self.master_key``.
        :return: the id of the newly created key.
        """
        # create master key
        response = self.kms_client.create_key(Description="dtest", Tags=[{"TagKey": "Name", "TagValue": "dtest"}])
        key_id = response["KeyMetadata"]["KeyId"]
        if alias_name is None:
            alias_name = self.master_key
        try:
            self.kms_client.delete_alias(AliasName=alias_name)
        except Exception as exc:
            # Best effort: any failure to delete the alias is deliberately
            # ignored (e.g. it may simply not exist yet); leave a trace
            # instead of swallowing it silently.
            logging.getLogger(__name__).debug("Ignoring failure to delete KMS alias %s: %s", alias_name, exc)
        self.kms_client.create_alias(AliasName=alias_name, TargetKeyId=key_id)
        return key_id

    def configuration_parameters(self) -> dict[str, str]:
        """scylla.conf entries for provider: base entries plus a ``kms_hosts`` section."""
        options = {
            "master_key": self.master_key
        }
        if self.endpoint_url is not None:
            options['endpoint'] = self.endpoint_url
        if self.region is not None:
            options['aws_regions'] = self.region
        return super().configuration_parameters() | { "kms_hosts": { self.kms_host: options } }

    def additional_cf_options(self):
        """Keyspace options: the base options plus ``kms_host``."""
        return super().additional_cf_options() | {"kms_host": self.kms_host}

    def supported_cipher(self, cipher_algorithm, secret_key_strength):
        """Only key strengths of at least 128 are supported."""
        return secret_key_strength >= 128

    def require_restart(self):
        """Servers must be restarted on keyspace parameter changes."""
        return True
class AzureKeyProviderFactory(KeyProviderFactory):
    """Factory helper for the Azure key provider.

    Entering the context starts a ``MockAzureVaultServer`` and records the
    address it reports; leaving the context stops it.
    """

    def __init__(self, tmpdir):
        super().__init__(KeyProvider.azure, tmpdir)
        self.tmpdir = tmpdir
        self.azure_host = "azure_test"
        self.azure_server = None
        # Filled in from the mock server's address once it is started.
        self.host = None
        self.port = 0

    def configuration_parameters(self) -> dict[str, str]:
        """Return the scylla.conf entries: base entries plus an ``azure_hosts`` section."""
        host_options = {
            "master_key": f"http://{self.host}:{self.port}/mock-key",
            "azure_tenant_id": "00000000-1111-2222-3333-444444444444",
            "azure_client_id": "mock-client-id",
            "azure_client_secret": "mock-client-secret",
            "azure_authority_host": f"http://{self.host}:{self.port}",
        }
        base_parameters = super().configuration_parameters()
        return base_parameters | {"azure_hosts": {self.azure_host: host_options}}

    async def __aenter__(self):
        """Start the mock vault server and remember its host and port."""
        vault_logger = logging.getLogger('azure-vault-server')
        # Port 0 is passed in; the port actually used is read back below.
        self.azure_server = MockAzureVaultServer("127.0.0.1", 0, vault_logger)
        await self.azure_server.start()
        address = self.azure_server.server_address
        self.host, self.port = address[0], address[1]
        return self

    async def __aexit__(self, exception_type, exception_value, exception_traceback):
        """Stop the mock vault server, if it was started."""
        if self.azure_server is None:
            return
        await self.azure_server.stop()
        self.azure_server = None

    def additional_cf_options(self):
        """Return the keyspace options: the base options plus ``azure_host``."""
        base_options = super().additional_cf_options()
        return base_options | {"azure_host": self.azure_host}

    def require_restart(self):
        """Servers must be restarted on keyspace parameter changes."""
        return True
def make_key_provider_factory(provider: KeyProvider, tmpdir, scylla_exe):
    """Build the factory object matching a ``KeyProvider`` member.

    ``scylla_exe`` is only handed to the replicated provider; every factory
    receives ``tmpdir``. Raises RuntimeError for an unrecognised provider.
    """
    if provider == KeyProvider.local:
        return LocalFileSystemKeyProviderFactory(tmpdir)
    if provider == KeyProvider.replicated:
        return ReplicatedKeyProviderFactory(tmpdir, scylla_exe)
    if provider == KeyProvider.kmip:
        return KmipKeyProviderFactory(tmpdir)
    if provider == KeyProvider.kms:
        return KMSKeyProviderFactory(tmpdir)
    if provider == KeyProvider.azure:
        return AzureKeyProviderFactory(tmpdir)
    raise RuntimeError(f'Unknown key_provider: {provider}')