diff --git a/core/src/main/java/google/registry/keyring/KeyringModule.java b/core/src/main/java/google/registry/keyring/KeyringModule.java index 6cefbb851..c9989bfba 100644 --- a/core/src/main/java/google/registry/keyring/KeyringModule.java +++ b/core/src/main/java/google/registry/keyring/KeyringModule.java @@ -14,6 +14,7 @@ package google.registry.keyring; +import com.google.common.collect.ImmutableList; import dagger.Binds; import dagger.Module; import dagger.Provides; @@ -21,7 +22,6 @@ import google.registry.config.RegistryConfig.Config; import google.registry.keyring.api.Keyring; import google.registry.keyring.secretmanager.SecretManagerKeyring; import jakarta.inject.Singleton; -import java.util.Optional; /** Dagger module for {@link Keyring} */ @Module @@ -38,9 +38,10 @@ public abstract class KeyringModule { } @Provides - @Config("cloudSqlReplicaInstanceConnectionName") - public static Optional provideCloudSqlReplicaInstanceConnectionName(Keyring keyring) { - return Optional.ofNullable(keyring.getSqlReplicaConnectionName()); + @Config("cloudSqlReplicaInstanceConnectionNames") + public static ImmutableList provideCloudSqlReplicaInstanceConnectionNames( + Keyring keyring) { + return ImmutableList.copyOf(keyring.getSqlReplicaConnectionNames()); } @Provides diff --git a/core/src/main/java/google/registry/keyring/api/Keyring.java b/core/src/main/java/google/registry/keyring/api/Keyring.java index db3d45877..10bad414f 100644 --- a/core/src/main/java/google/registry/keyring/api/Keyring.java +++ b/core/src/main/java/google/registry/keyring/api/Keyring.java @@ -14,6 +14,7 @@ package google.registry.keyring.api; +import com.google.common.collect.ImmutableList; import javax.annotation.concurrent.ThreadSafe; import org.bouncycastle.openpgp.PGPKeyPair; import org.bouncycastle.openpgp.PGPPrivateKey; @@ -151,9 +152,17 @@ public interface Keyring extends AutoCloseable { /** Returns the Cloud SQL connection name of the primary database instance. */ String getSqlPrimaryConnectionName(); - /** Returns the Cloud SQL connection name of the replica database instance. */ + /** + * Returns the Cloud SQL connection name of the replica database instance. + * + *

Note: It is likely a better idea to use multiple replicas and {@link + * #getSqlReplicaConnectionNames()} instead. + */ String getSqlReplicaConnectionName(); + /** Returns the Cloud SQL connection names of the replica database instances. */ + ImmutableList getSqlReplicaConnectionNames(); + // Don't throw so try-with-resources works better. @Override void close(); diff --git a/core/src/main/java/google/registry/keyring/secretmanager/SecretManagerKeyring.java b/core/src/main/java/google/registry/keyring/secretmanager/SecretManagerKeyring.java index 6fc18b510..6ab7fbafa 100644 --- a/core/src/main/java/google/registry/keyring/secretmanager/SecretManagerKeyring.java +++ b/core/src/main/java/google/registry/keyring/secretmanager/SecretManagerKeyring.java @@ -17,6 +17,8 @@ package google.registry.keyring.secretmanager; import static com.google.common.base.CaseFormat.LOWER_HYPHEN; import static com.google.common.base.CaseFormat.UPPER_UNDERSCORE; +import com.google.common.base.Splitter; +import com.google.common.collect.ImmutableList; import google.registry.keyring.api.KeySerializer; import google.registry.keyring.api.Keyring; import google.registry.keyring.api.KeyringException; @@ -66,7 +68,8 @@ public class SecretManagerKeyring implements Keyring { RDE_SSH_CLIENT_PUBLIC_STRING, SAFE_BROWSING_API_KEY, SQL_PRIMARY_CONN_NAME, - SQL_REPLICA_CONN_NAME; + SQL_REPLICA_CONN_NAME, + SQL_REPLICA_CONN_NAMES; String getLabel() { return UPPER_UNDERSCORE.to(LOWER_HYPHEN, name()); @@ -157,7 +160,25 @@ public class SecretManagerKeyring implements Keyring { @Override public String getSqlReplicaConnectionName() { - return getString(StringKeyLabel.SQL_REPLICA_CONN_NAME); + try { + return getString(StringKeyLabel.SQL_REPLICA_CONN_NAME); + } catch (KeyringException e) { + return null; + } + } + + @Override + public ImmutableList getSqlReplicaConnectionNames() { + try { + String names = getString(StringKeyLabel.SQL_REPLICA_CONN_NAMES); + return ImmutableList.copyOf( + Splitter.on('\n').trimResults().omitEmptyStrings().splitToList(names)); + } catch (KeyringException e) { + String replicaConnectionName = getSqlReplicaConnectionName(); + return replicaConnectionName == null + ? ImmutableList.of() + : ImmutableList.of(replicaConnectionName); + } } /** No persistent resources are maintained for this Keyring implementation. */ diff --git a/core/src/main/java/google/registry/keyring/secretmanager/SecretManagerKeyringUpdater.java b/core/src/main/java/google/registry/keyring/secretmanager/SecretManagerKeyringUpdater.java index 553c5ba1f..4d7370eda 100644 --- a/core/src/main/java/google/registry/keyring/secretmanager/SecretManagerKeyringUpdater.java +++ b/core/src/main/java/google/registry/keyring/secretmanager/SecretManagerKeyringUpdater.java @@ -34,6 +34,7 @@ import static google.registry.keyring.secretmanager.SecretManagerKeyring.StringK import static google.registry.keyring.secretmanager.SecretManagerKeyring.StringKeyLabel.SAFE_BROWSING_API_KEY; import static google.registry.keyring.secretmanager.SecretManagerKeyring.StringKeyLabel.SQL_PRIMARY_CONN_NAME; import static google.registry.keyring.secretmanager.SecretManagerKeyring.StringKeyLabel.SQL_REPLICA_CONN_NAME; +import static google.registry.keyring.secretmanager.SecretManagerKeyring.StringKeyLabel.SQL_REPLICA_CONN_NAMES; import static google.registry.util.PreconditionsUtils.checkArgumentNotNull; import com.google.common.flogger.FluentLogger; @@ -134,6 +135,10 @@ public final class SecretManagerKeyringUpdater { return setString(name, SQL_REPLICA_CONN_NAME); } + public SecretManagerKeyringUpdater setSqlReplicaConnectionNames(String names) { + return setString(names, SQL_REPLICA_CONN_NAMES); + } + /** * Persists the secrets in the Secret Manager. * diff --git a/core/src/main/java/google/registry/persistence/PersistenceModule.java b/core/src/main/java/google/registry/persistence/PersistenceModule.java index 7e1fdd9e2..9c5844fe1 100644 --- a/core/src/main/java/google/registry/persistence/PersistenceModule.java +++ b/core/src/main/java/google/registry/persistence/PersistenceModule.java @@ -15,6 +15,7 @@ package google.registry.persistence; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; import static google.registry.config.RegistryConfig.getHibernateConnectionIsolation; import static google.registry.config.RegistryConfig.getHibernateHikariConnectionTimeout; import static google.registry.config.RegistryConfig.getHibernateHikariIdleTimeout; @@ -28,6 +29,7 @@ import static google.registry.persistence.transaction.TransactionManagerFactory. import com.google.auth.oauth2.GoogleCredentials; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Ascii; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; import dagger.BindsOptionalOf; @@ -36,6 +38,7 @@ import dagger.Module; import dagger.Provides; import google.registry.config.RegistryConfig.Config; import google.registry.persistence.transaction.CloudSqlCredentialSupplier; +import google.registry.persistence.transaction.DelegatingReplicaJpaTransactionManager; import google.registry.persistence.transaction.JpaTransactionManager; import google.registry.persistence.transaction.JpaTransactionManagerImpl; import google.registry.persistence.transaction.TransactionManager; @@ -59,6 +62,7 @@ import java.util.HashMap; import java.util.Map; import java.util.Optional; import java.util.Properties; +import java.util.Random; import java.util.function.Supplier; import javax.annotation.Nullable; import org.hibernate.cfg.Environment; @@ -264,16 +268,13 @@ public abstract class PersistenceModule { static JpaTransactionManager provideReadOnlyReplicaJpaTm( SqlCredentialStore credentialStore, @PartialCloudSqlConfigs ImmutableMap cloudSqlConfigs, - @Config("cloudSqlReplicaInstanceConnectionName") - Optional replicaInstanceConnectionName, - Clock clock) { + @Config("cloudSqlReplicaInstanceConnectionNames") + ImmutableList replicaInstanceConnectionNames, + Clock clock, + Random random) { HashMap overrides = Maps.newHashMap(cloudSqlConfigs); setSqlCredential(credentialStore, new RobotUser(RobotId.NOMULUS), overrides); - replicaInstanceConnectionName.ifPresent( - name -> overrides.put(HIKARI_DS_CLOUD_SQL_INSTANCE, name)); - overrides.put( - Environment.ISOLATION, TransactionIsolationLevel.TRANSACTION_REPEATABLE_READ.name()); - return new JpaTransactionManagerImpl(create(overrides), clock, true); + return createReplicaJpaTm(overrides, replicaInstanceConnectionNames, clock, random); } @Provides @@ -281,15 +282,34 @@ public abstract class PersistenceModule { @BeamReadOnlyReplicaJpaTm static JpaTransactionManager provideBeamReadOnlyReplicaJpaTm( @BeamPipelineCloudSqlConfigs ImmutableMap beamCloudSqlConfigs, - @Config("cloudSqlReplicaInstanceConnectionName") - Optional replicaInstanceConnectionName, - Clock clock) { + @Config("cloudSqlReplicaInstanceConnectionNames") + ImmutableList replicaInstanceConnectionNames, + Clock clock, + Random random) { HashMap overrides = Maps.newHashMap(beamCloudSqlConfigs); - replicaInstanceConnectionName.ifPresent( - name -> overrides.put(HIKARI_DS_CLOUD_SQL_INSTANCE, name)); - overrides.put( + return createReplicaJpaTm(overrides, replicaInstanceConnectionNames, clock, random); + } + + private static JpaTransactionManager createReplicaJpaTm( + Map baseOverrides, + ImmutableList replicaInstanceConnectionNames, + Clock clock, + Random random) { + baseOverrides.put( Environment.ISOLATION, TransactionIsolationLevel.TRANSACTION_REPEATABLE_READ.name()); - return new JpaTransactionManagerImpl(create(overrides), clock, true); + if (replicaInstanceConnectionNames.isEmpty()) { + return new JpaTransactionManagerImpl(create(baseOverrides), clock, true); + } + ImmutableList replicas = + replicaInstanceConnectionNames.stream() + .map( + name -> { + HashMap overrides = Maps.newHashMap(baseOverrides); + overrides.put(HIKARI_DS_CLOUD_SQL_INSTANCE, name); + return new JpaTransactionManagerImpl(create(overrides), clock, true); + }) + .collect(toImmutableList()); + return new DelegatingReplicaJpaTransactionManager(replicas, random); } /** Constructs the {@link EntityManagerFactory} instance. */ diff --git a/core/src/main/java/google/registry/persistence/transaction/DelegatingReplicaJpaTransactionManager.java b/core/src/main/java/google/registry/persistence/transaction/DelegatingReplicaJpaTransactionManager.java new file mode 100644 index 000000000..8d7b647d4 --- /dev/null +++ b/core/src/main/java/google/registry/persistence/transaction/DelegatingReplicaJpaTransactionManager.java @@ -0,0 +1,361 @@ +// Copyright 2026 The Nomulus Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package google.registry.persistence.transaction; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.common.collect.ImmutableCollection; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import google.registry.model.ImmutableObject; +import google.registry.persistence.PersistenceModule.TransactionIsolationLevel; +import google.registry.persistence.VKey; +import jakarta.persistence.EntityManager; +import jakarta.persistence.Query; +import jakarta.persistence.TypedQuery; +import jakarta.persistence.criteria.CriteriaQuery; +import jakarta.persistence.metamodel.Metamodel; +import java.time.Instant; +import java.util.Optional; +import java.util.Random; +import java.util.concurrent.Callable; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Function; +import java.util.stream.Stream; +import org.joda.time.DateTime; + +/** + * A {@link JpaTransactionManager} that load-balances across multiple read-only replicas. + * + *

For each top-level transaction, one replica is chosen and used for the duration of the + * transaction. For non-transactional methods, a replica is chosen for each call. + */ +public class DelegatingReplicaJpaTransactionManager implements JpaTransactionManager { + + private final ImmutableList replicas; + private final Random random; + private static final AtomicLong nextId = new AtomicLong(1); + + private static final ThreadLocal activeReplica = new ThreadLocal<>(); + + public DelegatingReplicaJpaTransactionManager( + ImmutableList replicas, Random random) { + checkArgument(!replicas.isEmpty(), "At least one replica must be provided"); + this.replicas = replicas; + this.random = random; + } + + private JpaTransactionManager getReplica() { + JpaTransactionManager replica = activeReplica.get(); + if (replica != null) { + return replica; + } + return getRandomReplica(); + } + + private T runMaybeAssigningReplica(Function work) { + JpaTransactionManager existing = activeReplica.get(); + if (existing != null) { + return work.apply(existing); + } + JpaTransactionManager replica = getRandomReplica(); + activeReplica.set(replica); + try { + return work.apply(replica); + } finally { + activeReplica.remove(); + } + } + + private JpaTransactionManager getRandomReplica() { + return replicas.get(random.nextInt(replicas.size())); + } + + @Override + public boolean inTransaction() { + var replica = activeReplica.get(); + return replica != null && replica.inTransaction(); + } + + @Override + public void assertInTransaction() { + JpaTransactionManager replica = activeReplica.get(); + if (replica == null) { + throw new IllegalStateException("Not in a transaction"); + } + replica.assertInTransaction(); + } + + @Override + public long allocateId() { + return nextId.getAndIncrement(); + } + + @Override + public T transact(Callable work) { + return transact(null, work, false); + } + + @Override + public T transact(TransactionIsolationLevel isolationLevel, Callable work) { + return transact(isolationLevel, work, false); + } + + @Override + public T transactNoRetry(Callable work) { + return transactNoRetry(null, work, false); + } + + @Override + public T transactNoRetry(TransactionIsolationLevel isolationLevel, Callable work) { + return transactNoRetry(isolationLevel, work, false); + } + + @Override + public T reTransact(Callable work) { + return runMaybeAssigningReplica(replica -> replica.reTransact(work)); + } + + @Override + public void transact(ThrowingRunnable work) { + transact( + () -> { + work.run(); + return null; + }); + } + + @Override + public void transact(TransactionIsolationLevel isolationLevel, ThrowingRunnable work) { + transact( + isolationLevel, + () -> { + work.run(); + return null; + }); + } + + @Override + public void reTransact(ThrowingRunnable work) { + reTransact( + () -> { + work.run(); + return null; + }); + } + + @Override + public DateTime getTransactionTime() { + return getReplica().getTransactionTime(); + } + + @Override + public Instant getTxTime() { + return getReplica().getTxTime(); + } + + @Override + public void insert(Object entity) { + getReplica().insert(entity); + } + + @Override + public void insertAll(ImmutableCollection entities) { + throw new UnsupportedOperationException("This is a replica database"); + } + + @Override + public void insertAll(ImmutableObject... entities) { + throw new UnsupportedOperationException("This is a replica database"); + } + + @Override + public void put(Object entity) { + throw new UnsupportedOperationException("This is a replica database"); + } + + @Override + public void putAll(ImmutableObject... entities) { + throw new UnsupportedOperationException("This is a replica database"); + } + + @Override + public void putAll(ImmutableCollection entities) { + throw new UnsupportedOperationException("This is a replica database"); + } + + @Override + public void update(Object entity) { + throw new UnsupportedOperationException("This is a replica database"); + } + + @Override + public void updateAll(ImmutableCollection entities) { + throw new UnsupportedOperationException("This is a replica database"); + } + + @Override + public void updateAll(ImmutableObject... entities) { + throw new UnsupportedOperationException("This is a replica database"); + } + + @Override + public boolean exists(Object entity) { + return getReplica().exists(entity); + } + + @Override + public boolean exists(VKey key) { + return getReplica().exists(key); + } + + @Override + public Optional loadByKeyIfPresent(VKey key) { + return getReplica().loadByKeyIfPresent(key); + } + + @Override + public ImmutableMap, T> loadByKeysIfPresent( + Iterable> keys) { + return getReplica().loadByKeysIfPresent(keys); + } + + @Override + public ImmutableList loadByEntitiesIfPresent(Iterable entities) { + return getReplica().loadByEntitiesIfPresent(entities); + } + + @Override + public T loadByKey(VKey key) { + return getReplica().loadByKey(key); + } + + @Override + public ImmutableMap, T> loadByKeys( + Iterable> keys) { + return getReplica().loadByKeys(keys); + } + + @Override + public T loadByEntity(T entity) { + return getReplica().loadByEntity(entity); + } + + @Override + public ImmutableList loadByEntities(Iterable entities) { + return getReplica().loadByEntities(entities); + } + + @Override + public ImmutableList loadAllOf(Class clazz) { + return getReplica().loadAllOf(clazz); + } + + @Override + public Stream loadAllOfStream(Class clazz) { + return getReplica().loadAllOfStream(clazz); + } + + @Override + public Optional loadSingleton(Class clazz) { + return getReplica().loadSingleton(clazz); + } + + @Override + public void delete(VKey key) { + throw new UnsupportedOperationException("This is a replica database"); + } + + @Override + public void delete(Iterable> keys) { + throw new UnsupportedOperationException("This is a replica database"); + } + + @Override + public T delete(T entity) { + throw new UnsupportedOperationException("This is a replica database"); + } + + @Override + public QueryComposer createQueryComposer(Class entity) { + return getReplica().createQueryComposer(entity); + } + + @Override + public EntityManager getStandaloneEntityManager() { + return getReplica().getStandaloneEntityManager(); + } + + @Override + public Metamodel getMetaModel() { + return getReplica().getMetaModel(); + } + + @Override + public EntityManager getEntityManager() { + return getReplica().getEntityManager(); + } + + @Override + public TypedQuery query(String sqlString, Class resultClass) { + return getReplica().query(sqlString, resultClass); + } + + @Override + public TypedQuery criteriaQuery(CriteriaQuery criteriaQuery) { + return getReplica().criteriaQuery(criteriaQuery); + } + + @Override + public Query query(String sqlString) { + return getReplica().query(sqlString); + } + + @Override + public void assertDelete(VKey key) { + throw new UnsupportedOperationException("This is a replica database"); + } + + @Override + public void teardown() { + for (JpaTransactionManager replica : replicas) { + replica.teardown(); + } + } + + @Override + public TransactionIsolationLevel getDefaultTransactionIsolationLevel() { + return replicas.get(0).getDefaultTransactionIsolationLevel(); + } + + @Override + public TransactionIsolationLevel getCurrentTransactionIsolationLevel() { + return getReplica().getCurrentTransactionIsolationLevel(); + } + + @Override + public T transact( + TransactionIsolationLevel isolationLevel, Callable work, boolean logSqlStatements) { + return runMaybeAssigningReplica( + replica -> replica.transact(isolationLevel, work, logSqlStatements)); + } + + @Override + public T transactNoRetry( + TransactionIsolationLevel isolationLevel, Callable work, boolean logSqlStatements) { + return runMaybeAssigningReplica( + replica -> replica.transactNoRetry(isolationLevel, work, logSqlStatements)); + } +} diff --git a/core/src/main/java/google/registry/tools/GetKeyringSecretCommand.java b/core/src/main/java/google/registry/tools/GetKeyringSecretCommand.java index 53837a3b1..e415a464c 100644 --- a/core/src/main/java/google/registry/tools/GetKeyringSecretCommand.java +++ b/core/src/main/java/google/registry/tools/GetKeyringSecretCommand.java @@ -14,6 +14,7 @@ package google.registry.tools; + import com.beust.jcommander.Parameter; import com.beust.jcommander.Parameters; import google.registry.keyring.api.KeySerializer; @@ -95,6 +96,10 @@ final class GetKeyringSecretCommand implements Command { out.write(KeySerializer.serializeString(keyring.getSqlPrimaryConnectionName())); case SQL_REPLICA_CONN_NAME -> out.write(KeySerializer.serializeString(keyring.getSqlReplicaConnectionName())); + case SQL_REPLICA_CONN_NAMES -> + out.write( + KeySerializer.serializeString( + String.join("\n", keyring.getSqlReplicaConnectionNames()))); } } } diff --git a/core/src/main/java/google/registry/tools/UpdateKeyringSecretCommand.java b/core/src/main/java/google/registry/tools/UpdateKeyringSecretCommand.java index 07ffc2a96..3561a435f 100644 --- a/core/src/main/java/google/registry/tools/UpdateKeyringSecretCommand.java +++ b/core/src/main/java/google/registry/tools/UpdateKeyringSecretCommand.java @@ -100,6 +100,8 @@ final class UpdateKeyringSecretCommand implements Command { secretManagerKeyringUpdater.setSqlPrimaryConnectionName(deserializeString(input)); case SQL_REPLICA_CONN_NAME -> secretManagerKeyringUpdater.setSqlReplicaConnectionName(deserializeString(input)); + case SQL_REPLICA_CONN_NAMES -> + secretManagerKeyringUpdater.setSqlReplicaConnectionNames(deserializeString(input)); } secretManagerKeyringUpdater.update(); diff --git a/core/src/main/java/google/registry/tools/params/KeyringKeyName.java b/core/src/main/java/google/registry/tools/params/KeyringKeyName.java index c950fa4e2..a2e29a95c 100644 --- a/core/src/main/java/google/registry/tools/params/KeyringKeyName.java +++ b/core/src/main/java/google/registry/tools/params/KeyringKeyName.java @@ -38,5 +38,6 @@ public enum KeyringKeyName { RDE_STAGING_PUBLIC_KEY, SAFE_BROWSING_API_KEY, SQL_PRIMARY_CONN_NAME, - SQL_REPLICA_CONN_NAME + SQL_REPLICA_CONN_NAME, + SQL_REPLICA_CONN_NAMES } diff --git a/core/src/test/java/google/registry/keyring/secretmanager/SecretManagerKeyringUpdaterTest.java b/core/src/test/java/google/registry/keyring/secretmanager/SecretManagerKeyringUpdaterTest.java index c9fc30d32..90a7fa273 100644 --- a/core/src/test/java/google/registry/keyring/secretmanager/SecretManagerKeyringUpdaterTest.java +++ b/core/src/test/java/google/registry/keyring/secretmanager/SecretManagerKeyringUpdaterTest.java @@ -120,6 +120,23 @@ public class SecretManagerKeyringUpdaterTest { verifyPersistedSecret("sql-replica-conn-name", name); } + @Test + void sqlReplicaConnectionNames() { + String names = "name1\nname2"; + updater.setSqlReplicaConnectionNames(names).update(); + + assertThat(keyring.getSqlReplicaConnectionNames()).containsExactly("name1", "name2").inOrder(); + verifyPersistedSecret("sql-replica-conn-names", names); + } + + @Test + void sqlReplicaConnectionNames_fallback() { + String name = "name"; + updater.setSqlReplicaConnectionName(name).update(); + + assertThat(keyring.getSqlReplicaConnectionNames()).containsExactly(name); + } + @Test void marksdbDnlLoginAndPassword() { String secret = "marksdbDnlLoginAndPassword"; diff --git a/core/src/test/java/google/registry/persistence/transaction/DelegatingReplicaJpaTransactionManagerTest.java b/core/src/test/java/google/registry/persistence/transaction/DelegatingReplicaJpaTransactionManagerTest.java new file mode 100644 index 000000000..286793748 --- /dev/null +++ b/core/src/test/java/google/registry/persistence/transaction/DelegatingReplicaJpaTransactionManagerTest.java @@ -0,0 +1,135 @@ +// Copyright 2026 The Nomulus Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package google.registry.persistence.transaction; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableList; +import java.util.Random; +import java.util.concurrent.Callable; +import org.junit.jupiter.api.Test; + +/** Tests for {@link DelegatingReplicaJpaTransactionManager}. */ +public class DelegatingReplicaJpaTransactionManagerTest { + + private JpaTransactionManager replica1 = mock(JpaTransactionManager.class); + private JpaTransactionManager replica2 = mock(JpaTransactionManager.class); + private Random random = mock(Random.class); + private DelegatingReplicaJpaTransactionManager transactionManager = + new DelegatingReplicaJpaTransactionManager(ImmutableList.of(replica1, replica2), random); + + @Test + void testGetReplica_rotates() { + when(random.nextInt(2)).thenReturn(0).thenReturn(1); + + transactionManager.loadByKey(null); + verify(replica1).loadByKey(null); + + transactionManager.loadByKey(null); + verify(replica2).loadByKey(null); + } + + @Test + void testTransact_usesSameReplica() throws Exception { + when(random.nextInt(2)).thenReturn(1); + when(replica2.transact(any(), any(), anyBoolean())) + .thenAnswer( + invocation -> { + Callable work = invocation.getArgument(1); + return work.call(); + }); + + transactionManager.transact( + () -> { + transactionManager.loadByKey(null); + return null; + }); + + verify(replica2).transact(any(), any(), anyBoolean()); + // The loadByKey inside the transact should also use replica2. + verify(replica2).loadByKey(null); + // And it should NOT have called random again for the nested call. + verify(random).nextInt(2); + } + + @Test + void testTransactNoRetry_usesSameReplica() throws Exception { + when(random.nextInt(2)).thenReturn(0); + when(replica1.transactNoRetry(any(), any(), anyBoolean())) + .thenAnswer( + invocation -> { + Callable work = invocation.getArgument(1); + return work.call(); + }); + + transactionManager.transactNoRetry( + () -> { + transactionManager.loadByKey(null); + return null; + }); + + verify(replica1).transactNoRetry(any(), any(), anyBoolean()); + verify(replica1).loadByKey(null); + verify(random).nextInt(2); + } + + @Test + void testReTransactNoRetry_usesSameReplica() throws Exception { + when(random.nextInt(2)).thenReturn(0); + when(replica1.reTransact(any(Callable.class))) + .thenAnswer( + invocation -> { + Callable work = invocation.getArgument(0); + return work.call(); + }); + + transactionManager.reTransact( + () -> { + transactionManager.loadByKey(null); + return null; + }); + + verify(replica1).reTransact(any(Callable.class)); + verify(replica1).loadByKey(null); + verify(random).nextInt(2); + } + + @Test + void testInTransaction() { + when(random.nextInt(2)).thenReturn(0); + when(replica1.inTransaction()).thenReturn(true); + + // Not in transaction yet + assertThat(transactionManager.inTransaction()).isFalse(); + + transactionManager.transact( + () -> { + assertThat(transactionManager.inTransaction()).isTrue(); + return null; + }); + } + + @Test + void testTeardown_tearsDownAllReplicas() { + transactionManager.teardown(); + verify(replica1).teardown(); + verify(replica2).teardown(); + } +} diff --git a/core/src/test/java/google/registry/testing/FakeKeyringModule.java b/core/src/test/java/google/registry/testing/FakeKeyringModule.java index 88f19838f..2ea687f77 100644 --- a/core/src/test/java/google/registry/testing/FakeKeyringModule.java +++ b/core/src/test/java/google/registry/testing/FakeKeyringModule.java @@ -19,6 +19,7 @@ import static google.registry.keyring.api.PgpHelper.KeyRequirement.SIGN; import static google.registry.testing.TestDataHelper.loadBytes; import static google.registry.testing.TestDataHelper.loadFile; +import com.google.common.collect.ImmutableList; import com.google.common.io.ByteSource; import dagger.Module; import dagger.Provides; @@ -57,7 +58,8 @@ public final class FakeKeyringModule { private static final String MARKSDB_SMDRL_LOGIN_AND_PASSWORD = "smdrl:yolo"; private static final String BSA_API_KEY = "bsaapikey"; private static final String SQL_PRIMARY_CONNECTION = "project:primary-region:primary-name"; - private static final String SQL_REPLICA_CONNECTION = "project:replica-region:replica-name"; + private static final String SQL_REPLICA_CONNECTION_1 = "project:replica-region:replica-name"; + private static final String SQL_REPLICA_CONNECTION_2 = "project:replica-region:replica-name-2"; @Provides public Keyring get() { @@ -160,7 +162,12 @@ public final class FakeKeyringModule { @Override public String getSqlReplicaConnectionName() { - return SQL_REPLICA_CONNECTION; + return SQL_REPLICA_CONNECTION_1; + } + + @Override + public ImmutableList getSqlReplicaConnectionNames() { + return ImmutableList.of(SQL_REPLICA_CONNECTION_1, SQL_REPLICA_CONNECTION_2); } @Override