cql-pytest: add new_user and new_session utils
These helpers can be used to create a new user and connect to the cluster using custom credentials to log in.
This commit is contained in:
@@ -11,13 +11,11 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from cassandra.auth import PlainTextAuthProvider
|
||||
from cassandra.cluster import Cluster, ConsistencyLevel, ExecutionProfile, EXEC_PROFILE_DEFAULT
|
||||
from cassandra.policies import RoundRobinPolicy
|
||||
from cassandra.cluster import Cluster
|
||||
from cassandra.connection import DRIVER_NAME, DRIVER_VERSION
|
||||
import ssl
|
||||
|
||||
from util import unique_name, new_test_table
|
||||
from util import unique_name, new_test_table, cql_session
|
||||
|
||||
# By default, tests run against a CQL server (Scylla or Cassandra) listening
|
||||
# on localhost:9042. Add the --host and --port options to allow overiding
|
||||
@@ -36,34 +34,15 @@ def pytest_addoption(parser):
|
||||
# We use scope="session" so that all tests will reuse the same client object.
|
||||
@pytest.fixture(scope="session")
|
||||
def cql(request):
|
||||
profile = ExecutionProfile(
|
||||
load_balancing_policy=RoundRobinPolicy(),
|
||||
consistency_level=ConsistencyLevel.LOCAL_QUORUM,
|
||||
serial_consistency_level=ConsistencyLevel.LOCAL_SERIAL,
|
||||
# The default timeout (in seconds) for execute() commands is 10, which
|
||||
# should have been more than enough, but in some extreme cases with a
|
||||
# very slow debug build running on a very busy machine and a very slow
|
||||
# request (e.g., a DROP KEYSPACE needing to drop multiple tables)
|
||||
# 10 seconds may not be enough, so let's increase it. See issue #7838.
|
||||
request_timeout = 120)
|
||||
if request.config.getoption('ssl'):
|
||||
# Scylla does not support any earlier TLS protocol. If you try,
|
||||
# you will get mysterious EOF errors (see issue #6971) :-(
|
||||
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
|
||||
else:
|
||||
ssl_context = None
|
||||
cluster = Cluster(execution_profiles={EXEC_PROFILE_DEFAULT: profile},
|
||||
contact_points=[request.config.getoption('host')],
|
||||
port=int(request.config.getoption('port')),
|
||||
# TODO: make the protocol version an option, to allow testing with
|
||||
# different versions. If we drop this setting completely, it will
|
||||
# mean pick the latest version supported by the client and the server.
|
||||
protocol_version=4,
|
||||
# Use the default superuser credentials, which work for both Scylla and Cassandra
|
||||
auth_provider=PlainTextAuthProvider(username='cassandra', password='cassandra'),
|
||||
ssl_context=ssl_context,
|
||||
)
|
||||
return cluster.connect()
|
||||
# Use the default superuser credentials, which work for both Scylla and Cassandra
|
||||
with cql_session(request.config.getoption('host'),
|
||||
request.config.getoption('port'),
|
||||
request.config.getoption('ssl'),
|
||||
username="cassandra",
|
||||
password="cassandra"
|
||||
) as session:
|
||||
yield session
|
||||
session.shutdown()
|
||||
|
||||
# A function-scoped autouse=True fixture allows us to test after every test
|
||||
# that the CQL connection is still alive - and if not report the test which
|
||||
|
||||
@@ -14,6 +14,10 @@ import os
|
||||
import collections
|
||||
from contextlib import contextmanager
|
||||
|
||||
from cassandra.auth import PlainTextAuthProvider
|
||||
from cassandra.cluster import Cluster, ConsistencyLevel, ExecutionProfile, EXEC_PROFILE_DEFAULT
|
||||
from cassandra.policies import RoundRobinPolicy
|
||||
|
||||
def random_string(length=10, chars=string.ascii_uppercase + string.digits):
|
||||
return ''.join(random.choice(chars) for x in range(length))
|
||||
|
||||
@@ -139,6 +143,55 @@ def new_secondary_index(cql, table, column, name='', extra=''):
|
||||
finally:
|
||||
cql.execute(f"DROP INDEX {keyspace}.{name}")
|
||||
|
||||
# Helper function for establishing a connection with given username and password
|
||||
@contextmanager
|
||||
def cql_session(host, port, ssl, username, password):
|
||||
profile = ExecutionProfile(
|
||||
load_balancing_policy=RoundRobinPolicy(),
|
||||
consistency_level=ConsistencyLevel.LOCAL_QUORUM,
|
||||
serial_consistency_level=ConsistencyLevel.LOCAL_SERIAL,
|
||||
# The default timeout (in seconds) for execute() commands is 10, which
|
||||
# should have been more than enough, but in some extreme cases with a
|
||||
# very slow debug build running on a very busy machine and a very slow
|
||||
# request (e.g., a DROP KEYSPACE needing to drop multiple tables)
|
||||
# 10 seconds may not be enough, so let's increase it. See issue #7838.
|
||||
request_timeout = 120)
|
||||
if ssl:
|
||||
# Scylla does not support any earlier TLS protocol. If you try,
|
||||
# you will get mysterious EOF errors (see issue #6971) :-(
|
||||
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
|
||||
else:
|
||||
ssl_context = None
|
||||
cluster = Cluster(execution_profiles={EXEC_PROFILE_DEFAULT: profile},
|
||||
contact_points=[host],
|
||||
port=int(port),
|
||||
# TODO: make the protocol version an option, to allow testing with
|
||||
# different versions. If we drop this setting completely, it will
|
||||
# mean pick the latest version supported by the client and the server.
|
||||
protocol_version=4,
|
||||
auth_provider=PlainTextAuthProvider(username=username, password=password),
|
||||
ssl_context=ssl_context,
|
||||
)
|
||||
yield cluster.connect()
|
||||
cluster.shutdown()
|
||||
|
||||
@contextmanager
|
||||
def new_user(cql, username=''):
|
||||
if not username:
|
||||
username = unique_name()
|
||||
cql.execute(f"CREATE ROLE {username} WITH PASSWORD = '{username}' AND LOGIN = true")
|
||||
try:
|
||||
yield username
|
||||
finally:
|
||||
cql.execute(f"DROP ROLE {username}")
|
||||
|
||||
@contextmanager
|
||||
def new_session(cql, username):
|
||||
endpoint = cql.hosts[0].endpoint
|
||||
with cql_session(host=endpoint.address, port=endpoint.port, ssl=False, username=username, password=username) as session:
|
||||
yield session
|
||||
session.shutdown()
|
||||
|
||||
def project(column_name_string, rows):
|
||||
"""Returns a list of column values from each of the rows."""
|
||||
return [getattr(r, column_name_string) for r in rows]
|
||||
|
||||
Reference in New Issue
Block a user