From 91db8583f8a066bd31db5dc4e7b0215651983817 Mon Sep 17 00:00:00 2001 From: Calle Wilund Date: Tue, 21 Oct 2025 08:30:12 +0000 Subject: [PATCH] test::pylib::kmip_wrapper: Modify to be usable by pytest fixtures Add `serve` impl that does not mess with signals, and shutdown that does not mess with threads. Also speed up standalone shutdown to make boost tests less slow. --- test/pylib/kmip_wrapper.py | 195 ++++++++++++++++++++++++++++--------- 1 file changed, 151 insertions(+), 44 deletions(-) diff --git a/test/pylib/kmip_wrapper.py b/test/pylib/kmip_wrapper.py index 85d6486229..7fbf59836e 100644 --- a/test/pylib/kmip_wrapper.py +++ b/test/pylib/kmip_wrapper.py @@ -1,13 +1,19 @@ +# +# Copyright (C) 2025-present ScyllaDB +# +# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 +# + import ssl import sys import functools - +import socket +import signal import sqlalchemy from sqlalchemy.pool import StaticPool from kmip.services import auth -from kmip.services.server.server import build_argument_parser -from kmip.services.server.server import KmipServer +from kmip.services.server.server import KmipServer, build_argument_parser, exceptions # Helper wrapper for running pykmip in scylla testing. Needed for the following # reasons: @@ -22,19 +28,6 @@ from kmip.services.server.server import KmipServer # itself must be shared. We achieve that with the StaticPool. # https://docs.sqlalchemy.org/en/20/dialects/sqlite.html#using-a-memory-database-in-multiple-threads - -def monkey_patch_create_engine(): - original_create_engine = sqlalchemy.create_engine - - @functools.wraps(original_create_engine) - def patched_create_engine(*args, **kwargs): - if args and isinstance(args[0], str) and args[0].startswith('sqlite:///:memory:'): - kwargs['poolclass'] = StaticPool - return original_create_engine(*args, **kwargs) - - sqlalchemy.create_engine = patched_create_engine - - class TLS13AuthenticationSuite(auth.TLS12AuthenticationSuite): """ An authentication suite used to establish secure network connections. @@ -53,10 +46,136 @@ class TLS13AuthenticationSuite(auth.TLS12AuthenticationSuite): super().__init__(cipher_suites) self._protocol = ssl.PROTOCOL_TLS_SERVER +class KMIPServerWrapper: + """Wrapper for PyKMIP server""" + def __init__(self, **kwargs): + self.kwargs = kwargs + self.kmip_server = None + self.original_create_engine = None + #self.original_wrap_socket = None + + @property + def port(self): + # pylint: disable=protected-access + """Listening port""" + return self.kmip_server._socket.getsockname()[1] + + def serve(self): + """server wrapper""" + self._serve_in_thread() + + def _serve_in_thread(self): + # pylint: disable=protected-access,broad-exception-caught + """Terrible copy of serve function, but without signal handlers""" + kmip_server = self.kmip_server + kmip_server._socket.listen(5) + kmip_server._logger.info("Starting connection service...") + + while kmip_server._is_serving: + try: + connection, address = kmip_server._socket.accept() + except socket.timeout: + # Setting the default socket timeout to break hung connections + # will cause accept to periodically raise socket.timeout. This + # is expected behavior, so ignore it and retry accept. + pass + except socket.error as e: + kmip_server._logger.warning( + "Error detected while establishing new connection." + ) + kmip_server._logger.exception(e) + except KeyboardInterrupt: + kmip_server._logger.warning("Interrupting connection service.") + kmip_server._is_serving = False + break + except Exception as e: + if kmip_server._is_serving: + kmip_server._logger.warning( + "Error detected while establishing new connection." + ) + kmip_server._logger.exception(e) + else: + kmip_server._setup_connection_handler(connection, address) + + kmip_server._logger.info("Stopping connection service.") + + def shutdown(self): + """Stop serving""" + # pylint: disable=protected-access + self.kmip_server._logger.info("Shutting down server socket handler.") + self.kmip_server._is_serving = False + self.kmip_server._socket.shutdown(socket.SHUT_RDWR) + self.kmip_server._socket.close() + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop() + + def start(self): + """start""" + self.original_create_engine = sqlalchemy.create_engine + + @functools.wraps(self.original_create_engine) + def patched_create_engine(*args, **kwargs): + if args and isinstance(args[0], str) and args[0].startswith('sqlite:///:memory:'): + kwargs['poolclass'] = StaticPool + return self.original_create_engine(*args, **kwargs) + + def fake_wrap_ssl(sock, keyfile=None, certfile=None, + server_side=False, cert_reqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_TLS, ca_certs=None, + do_handshake_on_connect=True, + suppress_ragged_eofs=True, + ciphers=None): + ctxt = ssl.SSLContext(protocol = ssl_version) + ctxt.load_cert_chain(certfile=certfile, keyfile=keyfile) + ctxt.verify_mode = cert_reqs + ctxt.load_verify_locations(cafile=ca_certs) + ctxt.set_ciphers(ciphers) + return ctxt.wrap_socket(sock, server_side=server_side + , do_handshake_on_connect=do_handshake_on_connect + , suppress_ragged_eofs=suppress_ragged_eofs) + + sqlalchemy.create_engine = patched_create_engine + ssl.wrap_socket = fake_wrap_ssl + + # Create and start the server. + self.kmip_server = KmipServer(**self.kwargs) + # Fix TLS. Try to get this into mainline project, but that will take time... + self.kmip_server.auth_suite = TLS13AuthenticationSuite(self.kmip_server.auth_suite.ciphers) + # force port to zero -> select dynamically + self.kmip_server.config.settings['port'] = 0 + self.kmip_server.start() + + def stop(self): + # pylint: disable=protected-access,broad-exception-caught + """stop""" + kmip_server = self.kmip_server + kmip_server._logger.info("Stopping...") + self.shutdown() + # KMIPServer stop is somewhat broken in that it assumes all threads belong to it. + # We can really just ignore the serving threads. They are daemons and are either + # done or not important at this point. + + if hasattr(kmip_server, "policy_monitor"): + try: + kmip_server.policy_monitor.stop() + kmip_server.policy_monitor.join() + except Exception as e: + kmip_server._logger.exception(e) + raise exceptions.ShutdownError("Server failed to clean up the policy monitor.") + + sqlalchemy.create_engine = self.original_create_engine + + def main(): + """Called from parent process""" # Build argument parser and parser command-line arguments. parser = build_argument_parser() - opts, args = parser.parse_args(sys.argv[1:]) + opts, _ = parser.parse_args(sys.argv[1:]) kwargs = {} if opts.hostname: @@ -86,37 +205,25 @@ def main(): kwargs['live_policies'] = True - monkey_patch_create_engine() - - # Create and start the server. - s = KmipServer(**kwargs) - # Fix TLS. Try to get this into mainline project, but that will take time... - s.auth_suite = TLS13AuthenticationSuite(s.auth_suite.ciphers) - # force port to zero -> select dynamically - s.config.settings['port'] = 0 - - def fake_wrap_ssl(sock, keyfile=None, certfile=None, - server_side=False, cert_reqs=ssl.CERT_NONE, - ssl_version=ssl.PROTOCOL_TLS, ca_certs=None, - do_handshake_on_connect=True, - suppress_ragged_eofs=True, - ciphers=None): - ctxt = ssl.SSLContext(protocol = ssl_version) - ctxt.load_cert_chain(certfile=certfile, keyfile=keyfile) - ctxt.verify_mode = cert_reqs - ctxt.load_verify_locations(cafile=ca_certs) - ctxt.set_ciphers(ciphers) - return ctxt.wrap_socket(sock, server_side=server_side - , do_handshake_on_connect=do_handshake_on_connect - , suppress_ragged_eofs=suppress_ragged_eofs) - - ssl.wrap_socket = fake_wrap_ssl - + s = KMIPServerWrapper(**kwargs) print("Starting...") with s: - print("Listening on {}".format(s._socket.getsockname()[1])) + print(f'Listening on {s.port}') sys.stdout.flush() + + # place signal handling here, and just do a throw + # to escape the serve loop. We will not wait for daemons + # or anything, just exit. This makes tests (boost) using + # this _not_ wait 10s at exit. + def _signal_handler(signal_number, stack_frame): + # pylint: disable=protected-access,unused-argument + s.kmip_server._is_serving = False + raise KeyboardInterrupt("signal received") + + signal.signal(signal.SIGINT, _signal_handler) + signal.signal(signal.SIGTERM, _signal_handler) + s.serve() if __name__ == '__main__':