diff --git a/test/cluster/test_audit.py b/test/cluster/test_audit.py index 1c4b5a72f2..3ae0fbbaf6 100644 --- a/test/cluster/test_audit.py +++ b/test/cluster/test_audit.py @@ -273,6 +273,7 @@ class AuditEntry: statement: str table: str user: str + source: str = "127.0.0.1" class AuditBackend: @@ -449,6 +450,13 @@ class AuditBackendSyslog(AuditBackend): entries.append(self.line_to_row(line, idx)) return { self.audit_mode(): entries } + @staticmethod + def _parse_address(addr_port): + """Extract IP from 'ip:port' (IPv4) or '[ip]:port' (IPv6).""" + if addr_port.startswith("["): + return addr_port[1:addr_port.index("]")] + return addr_port.split(":")[0] + def line_to_row(self, line, idx): metadata, data = line.split(": ", 1) data = "".join(data.splitlines()) # Remove newlines @@ -460,9 +468,9 @@ class AuditBackendSyslog(AuditBackend): # and make sure it doesn't change during the test (e.g. when the test is running at 23:59:59) date = datetime.datetime(2000, 1, 1, 0, 0) - node = match.group("node").split(":")[0] + node = self._parse_address(match.group("node")) statement = match.group("query").replace("\\", "") - source = match.group("client_ip").split(":")[0] + source = self._parse_address(match.group("client_ip")) event_time = uuid.UUID(int=idx) t = self.named_tuple_factory(date, node, event_time, match.group("category"), match.group("cl"), match.group("error") == "true", match.group("keyspace"), statement, source, match.group("table"), match.group("username")) return t @@ -582,6 +590,7 @@ class CQLAuditTester(AuditTester): user="anonymous", cl="ONE", error=False, + source="127.0.0.1", ): self.assert_audit_row_fields(row) assert row.node in self.server_addresses @@ -590,7 +599,7 @@ class CQLAuditTester(AuditTester): assert row.error == error assert row.keyspace_name == ks assert row.operation == statement - assert row.source == "127.0.0.1" + assert row.source == source assert row.table_name == table assert row.username == user @@ -814,7 +823,7 @@ class CQLAuditTester(AuditTester): sorted_new_rows = sorted(new_rows, key=lambda row: (row.node, row.category, row.consistency, row.error, row.keyspace_name, row.operation, row.source, row.table_name, row.username)) assert len(sorted_new_rows) == len(expected_entries) for row, entry in zip(sorted_new_rows, sorted(expected_entries)): - self.assert_audit_row_eq(row, entry.category, entry.statement, entry.table, entry.ks, entry.user, entry.cl, entry.error) + self.assert_audit_row_eq(row, entry.category, entry.statement, entry.table, entry.ks, entry.user, entry.cl, entry.error, entry.source) async def verify_keyspace(self, audit_settings=None, helper=None): """