mirror of
https://github.com/scylladb/scylladb.git
synced 2026-05-12 19:02:12 +00:00
test::pylib::kmip_wrapper: Modify to be usable by pytest fixtures
Add `serve` impl that does not mess with signals, and shutdown that does not mess with threads. Also speed up standalone shutdown to make boost tests less slow.
This commit is contained in:
@@ -1,13 +1,19 @@
|
||||
#
|
||||
# Copyright (C) 2025-present ScyllaDB
|
||||
#
|
||||
# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
|
||||
#
|
||||
|
||||
import ssl
|
||||
import sys
|
||||
import functools
|
||||
|
||||
import socket
|
||||
import signal
|
||||
import sqlalchemy
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from kmip.services import auth
|
||||
from kmip.services.server.server import build_argument_parser
|
||||
from kmip.services.server.server import KmipServer
|
||||
from kmip.services.server.server import KmipServer, build_argument_parser, exceptions
|
||||
|
||||
# Helper wrapper for running pykmip in scylla testing. Needed for the following
|
||||
# reasons:
|
||||
@@ -22,19 +28,6 @@ from kmip.services.server.server import KmipServer
|
||||
# itself must be shared. We achieve that with the StaticPool.
|
||||
# https://docs.sqlalchemy.org/en/20/dialects/sqlite.html#using-a-memory-database-in-multiple-threads
|
||||
|
||||
|
||||
def monkey_patch_create_engine():
|
||||
original_create_engine = sqlalchemy.create_engine
|
||||
|
||||
@functools.wraps(original_create_engine)
|
||||
def patched_create_engine(*args, **kwargs):
|
||||
if args and isinstance(args[0], str) and args[0].startswith('sqlite:///:memory:'):
|
||||
kwargs['poolclass'] = StaticPool
|
||||
return original_create_engine(*args, **kwargs)
|
||||
|
||||
sqlalchemy.create_engine = patched_create_engine
|
||||
|
||||
|
||||
class TLS13AuthenticationSuite(auth.TLS12AuthenticationSuite):
|
||||
"""
|
||||
An authentication suite used to establish secure network connections.
|
||||
@@ -53,10 +46,136 @@ class TLS13AuthenticationSuite(auth.TLS12AuthenticationSuite):
|
||||
super().__init__(cipher_suites)
|
||||
self._protocol = ssl.PROTOCOL_TLS_SERVER
|
||||
|
||||
class KMIPServerWrapper:
|
||||
"""Wrapper for PyKMIP server"""
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
self.kmip_server = None
|
||||
self.original_create_engine = None
|
||||
#self.original_wrap_socket = None
|
||||
|
||||
@property
|
||||
def port(self):
|
||||
# pylint: disable=protected-access
|
||||
"""Listening port"""
|
||||
return self.kmip_server._socket.getsockname()[1]
|
||||
|
||||
def serve(self):
|
||||
"""server wrapper"""
|
||||
self._serve_in_thread()
|
||||
|
||||
def _serve_in_thread(self):
|
||||
# pylint: disable=protected-access,broad-exception-caught
|
||||
"""Terrible copy of serve function, but without signal handlers"""
|
||||
kmip_server = self.kmip_server
|
||||
kmip_server._socket.listen(5)
|
||||
kmip_server._logger.info("Starting connection service...")
|
||||
|
||||
while kmip_server._is_serving:
|
||||
try:
|
||||
connection, address = kmip_server._socket.accept()
|
||||
except socket.timeout:
|
||||
# Setting the default socket timeout to break hung connections
|
||||
# will cause accept to periodically raise socket.timeout. This
|
||||
# is expected behavior, so ignore it and retry accept.
|
||||
pass
|
||||
except socket.error as e:
|
||||
kmip_server._logger.warning(
|
||||
"Error detected while establishing new connection."
|
||||
)
|
||||
kmip_server._logger.exception(e)
|
||||
except KeyboardInterrupt:
|
||||
kmip_server._logger.warning("Interrupting connection service.")
|
||||
kmip_server._is_serving = False
|
||||
break
|
||||
except Exception as e:
|
||||
if kmip_server._is_serving:
|
||||
kmip_server._logger.warning(
|
||||
"Error detected while establishing new connection."
|
||||
)
|
||||
kmip_server._logger.exception(e)
|
||||
else:
|
||||
kmip_server._setup_connection_handler(connection, address)
|
||||
|
||||
kmip_server._logger.info("Stopping connection service.")
|
||||
|
||||
def shutdown(self):
|
||||
"""Stop serving"""
|
||||
# pylint: disable=protected-access
|
||||
self.kmip_server._logger.info("Shutting down server socket handler.")
|
||||
self.kmip_server._is_serving = False
|
||||
self.kmip_server._socket.shutdown(socket.SHUT_RDWR)
|
||||
self.kmip_server._socket.close()
|
||||
|
||||
def __enter__(self):
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.stop()
|
||||
|
||||
def start(self):
|
||||
"""start"""
|
||||
self.original_create_engine = sqlalchemy.create_engine
|
||||
|
||||
@functools.wraps(self.original_create_engine)
|
||||
def patched_create_engine(*args, **kwargs):
|
||||
if args and isinstance(args[0], str) and args[0].startswith('sqlite:///:memory:'):
|
||||
kwargs['poolclass'] = StaticPool
|
||||
return self.original_create_engine(*args, **kwargs)
|
||||
|
||||
def fake_wrap_ssl(sock, keyfile=None, certfile=None,
|
||||
server_side=False, cert_reqs=ssl.CERT_NONE,
|
||||
ssl_version=ssl.PROTOCOL_TLS, ca_certs=None,
|
||||
do_handshake_on_connect=True,
|
||||
suppress_ragged_eofs=True,
|
||||
ciphers=None):
|
||||
ctxt = ssl.SSLContext(protocol = ssl_version)
|
||||
ctxt.load_cert_chain(certfile=certfile, keyfile=keyfile)
|
||||
ctxt.verify_mode = cert_reqs
|
||||
ctxt.load_verify_locations(cafile=ca_certs)
|
||||
ctxt.set_ciphers(ciphers)
|
||||
return ctxt.wrap_socket(sock, server_side=server_side
|
||||
, do_handshake_on_connect=do_handshake_on_connect
|
||||
, suppress_ragged_eofs=suppress_ragged_eofs)
|
||||
|
||||
sqlalchemy.create_engine = patched_create_engine
|
||||
ssl.wrap_socket = fake_wrap_ssl
|
||||
|
||||
# Create and start the server.
|
||||
self.kmip_server = KmipServer(**self.kwargs)
|
||||
# Fix TLS. Try to get this into mainline project, but that will take time...
|
||||
self.kmip_server.auth_suite = TLS13AuthenticationSuite(self.kmip_server.auth_suite.ciphers)
|
||||
# force port to zero -> select dynamically
|
||||
self.kmip_server.config.settings['port'] = 0
|
||||
self.kmip_server.start()
|
||||
|
||||
def stop(self):
|
||||
# pylint: disable=protected-access,broad-exception-caught
|
||||
"""stop"""
|
||||
kmip_server = self.kmip_server
|
||||
kmip_server._logger.info("Stopping...")
|
||||
self.shutdown()
|
||||
# KMIPServer stop is somewhat broken in that it assumes all threads belong to it.
|
||||
# We can really just ignore the serving threads. They are daemons and are either
|
||||
# done or not important at this point.
|
||||
|
||||
if hasattr(kmip_server, "policy_monitor"):
|
||||
try:
|
||||
kmip_server.policy_monitor.stop()
|
||||
kmip_server.policy_monitor.join()
|
||||
except Exception as e:
|
||||
kmip_server._logger.exception(e)
|
||||
raise exceptions.ShutdownError("Server failed to clean up the policy monitor.")
|
||||
|
||||
sqlalchemy.create_engine = self.original_create_engine
|
||||
|
||||
|
||||
def main():
|
||||
"""Called from parent process"""
|
||||
# Build argument parser and parser command-line arguments.
|
||||
parser = build_argument_parser()
|
||||
opts, args = parser.parse_args(sys.argv[1:])
|
||||
opts, _ = parser.parse_args(sys.argv[1:])
|
||||
|
||||
kwargs = {}
|
||||
if opts.hostname:
|
||||
@@ -86,37 +205,25 @@ def main():
|
||||
|
||||
kwargs['live_policies'] = True
|
||||
|
||||
monkey_patch_create_engine()
|
||||
|
||||
# Create and start the server.
|
||||
s = KmipServer(**kwargs)
|
||||
# Fix TLS. Try to get this into mainline project, but that will take time...
|
||||
s.auth_suite = TLS13AuthenticationSuite(s.auth_suite.ciphers)
|
||||
# force port to zero -> select dynamically
|
||||
s.config.settings['port'] = 0
|
||||
|
||||
def fake_wrap_ssl(sock, keyfile=None, certfile=None,
|
||||
server_side=False, cert_reqs=ssl.CERT_NONE,
|
||||
ssl_version=ssl.PROTOCOL_TLS, ca_certs=None,
|
||||
do_handshake_on_connect=True,
|
||||
suppress_ragged_eofs=True,
|
||||
ciphers=None):
|
||||
ctxt = ssl.SSLContext(protocol = ssl_version)
|
||||
ctxt.load_cert_chain(certfile=certfile, keyfile=keyfile)
|
||||
ctxt.verify_mode = cert_reqs
|
||||
ctxt.load_verify_locations(cafile=ca_certs)
|
||||
ctxt.set_ciphers(ciphers)
|
||||
return ctxt.wrap_socket(sock, server_side=server_side
|
||||
, do_handshake_on_connect=do_handshake_on_connect
|
||||
, suppress_ragged_eofs=suppress_ragged_eofs)
|
||||
|
||||
ssl.wrap_socket = fake_wrap_ssl
|
||||
|
||||
s = KMIPServerWrapper(**kwargs)
|
||||
print("Starting...")
|
||||
|
||||
with s:
|
||||
print("Listening on {}".format(s._socket.getsockname()[1]))
|
||||
print(f'Listening on {s.port}')
|
||||
sys.stdout.flush()
|
||||
|
||||
# place signal handling here, and just do a throw
|
||||
# to escape the serve loop. We will not wait for daemons
|
||||
# or anything, just exit. This makes tests (boost) using
|
||||
# this _not_ wait 10s at exit.
|
||||
def _signal_handler(signal_number, stack_frame):
|
||||
# pylint: disable=protected-access,unused-argument
|
||||
s.kmip_server._is_serving = False
|
||||
raise KeyboardInterrupt("signal received")
|
||||
|
||||
signal.signal(signal.SIGINT, _signal_handler)
|
||||
signal.signal(signal.SIGTERM, _signal_handler)
|
||||
|
||||
s.serve()
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user