1
0
mirror of https://github.com/google/nomulus synced 2025-12-23 14:25:44 +00:00

Add cache for User entities in OIDC auth flow (#2822)

* Add cache for User entities in OIDC auth flow

* refactor: Address review feedback

- Refactor database call into a single, reusable method
- Increase the default cache size to 200
- Remove .recordStats() and using spy for testing
- Split unit tests into separate implementation test that use Mockito spies instead of checking internal cache stats
This commit is contained in:
Nilay Shah
2025-09-12 13:13:32 +05:30
committed by GitHub
parent 732c30b359
commit 06299ccb86
5 changed files with 170 additions and 3 deletions

View File

@@ -1591,6 +1591,26 @@ public final class RegistryConfig {
return CONFIG_SETTINGS.get().caching.eppResourceMaxCachedEntries;
}
/** Returns if we have enabled caching for User Authentication */
public static boolean getUserAuthCachingEnabled() {
return CONFIG_SETTINGS.get().caching.userAuthCachingEnabled;
}
@VisibleForTesting
public static void overrideIsUserAuthCachingEnabledForTesting(boolean enabled) {
CONFIG_SETTINGS.get().caching.userAuthCachingEnabled = enabled;
}
/** Returns the expiry duration for the user authentication cache. */
public static java.time.Duration getUserAuthCachingDuration() {
return java.time.Duration.ofSeconds(CONFIG_SETTINGS.get().caching.userAuthCachingSeconds);
}
/** Returns the maximum number of entries in user authentication cache. */
public static int getUserAuthMaxCachedEntries() {
return CONFIG_SETTINGS.get().caching.userAuthMaxCachedEntries;
}
/** Returns the amount of time that a particular claims list should be cached. */
public static java.time.Duration getClaimsListCacheDuration() {
return java.time.Duration.ofSeconds(CONFIG_SETTINGS.get().caching.claimsListCachingSeconds);

View File

@@ -161,6 +161,9 @@ public class RegistryConfigSettings {
public int eppResourceCachingSeconds;
public int eppResourceMaxCachedEntries;
public int claimsListCachingSeconds;
public boolean userAuthCachingEnabled;
public int userAuthCachingSeconds;
public int userAuthMaxCachedEntries;
}
/** Configuration for ICANN monthly reporting. */

View File

@@ -326,6 +326,20 @@ caching:
# long duration is acceptable because claims lists don't change frequently.
claimsListCachingSeconds: 21600 # six hours
#-- User Authentication Cache Settings --#
# Whether to cache User objects during OIDC token authentication to reduce database load.
# This helps mitigate high QPS from frequent hello commands and session-less requests.
userAuthCachingEnabled: true
# The duration in seconds for which a User object is cached after being loaded.
# A short duration is recommended to avoid stale data.
userAuthCachingSeconds: 60
# The maximum number of User objects to store in the cache per pod.
# This helps limit the memory footprint of the cache.
userAuthMaxCachedEntries: 200
# Note: Only allowedServiceAccountEmails and oauthClientId should be configured.
# Other fields are related to OAuth-based authentication and will be removed.
auth:

View File

@@ -15,14 +15,20 @@
package google.registry.request.auth;
import static com.google.common.base.Preconditions.checkState;
import static google.registry.config.RegistryConfig.getUserAuthCachingDuration;
import static google.registry.config.RegistryConfig.getUserAuthCachingEnabled;
import static google.registry.config.RegistryConfig.getUserAuthMaxCachedEntries;
import static google.registry.model.CacheUtils.newCacheBuilder;
import static google.registry.persistence.transaction.TransactionManagerFactory.tm;
import com.github.benmanes.caffeine.cache.LoadingCache;
import com.google.api.client.json.webtoken.JsonWebSignature;
import com.google.auth.oauth2.TokenVerifier.VerificationException;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableSet;
import com.google.common.flogger.FluentLogger;
import google.registry.config.RegistryConfig;
import google.registry.config.RegistryConfig.Config;
import google.registry.model.console.User;
import google.registry.persistence.VKey;
@@ -71,6 +77,40 @@ public abstract class OidcTokenAuthenticationMechanism implements Authentication
this.tokenVerifier = tokenVerifier;
}
/**
* An in-memory cache for User entities, built using the project's standard utility.
*
* <p>This cache reduces database load by temporarily storing User objects after they are fetched.
* It is configured to cache negative results (i.e., when a user is not found) to prevent repeated
* lookups for invalid users. The cache's behavior (enabled, expiry, size) is controlled by
* settings in {@link RegistryConfig}.
*/
@VisibleForTesting
static LoadingCache<String, Optional<User>> userCache =
newCacheBuilder(getUserAuthCachingDuration())
.maximumSize(getUserAuthMaxCachedEntries())
.build(OidcTokenAuthenticationMechanism::loadUser);
/**
* A loader function that defines how to fetch a User from the database on a cache miss.
*
* <p>This is the single point of entry to the database for this authentication flow. It will only
* be invoked by the cache when a requested user is not already in memory.
*/
@VisibleForTesting
static Optional<User> loadUser(String email) {
VKey<User> userVKey = VKey.create(User.class, email);
return tm().transact(() -> tm().loadByKeyIfPresent(userVKey));
}
@VisibleForTesting
public static void setCacheForTesting(LoadingCache<String, Optional<User>> cache) {
checkState(
RegistryEnvironment.get() == RegistryEnvironment.UNITTEST,
"Cannot set cache outside of a test environment");
OidcTokenAuthenticationMechanism.userCache = cache;
}
@Override
public AuthResult authenticate(HttpServletRequest request) {
if (RegistryEnvironment.get().equals(RegistryEnvironment.UNITTEST)
@@ -112,8 +152,15 @@ public abstract class OidcTokenAuthenticationMechanism implements Authentication
logger.atInfo().log("No email address from the OIDC token:\n%s", token.getPayload());
return AuthResult.NOT_AUTHENTICATED;
}
Optional<User> maybeUser =
tm().transact(() -> tm().loadByKeyIfPresent(VKey.create(User.class, email)));
Optional<User> maybeUser;
if (getUserAuthCachingEnabled()) {
// If caching is ON, use the cache.
maybeUser = userCache.get(email);
} else {
// If caching is OFF, fall back to the original direct database call.
maybeUser = loadUser(email);
}
stopwatch.tick("OidcTokenAuthenticationMechanism maybeUser loaded");
if (maybeUser.isPresent()) {
return AuthResult.createUser(maybeUser.get());

View File

@@ -16,13 +16,20 @@ package google.registry.request.auth;
import static com.google.common.net.HttpHeaders.AUTHORIZATION;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.config.RegistryConfig.getUserAuthCachingDuration;
import static google.registry.config.RegistryConfig.getUserAuthMaxCachedEntries;
import static google.registry.request.auth.AuthModule.BEARER_PREFIX;
import static google.registry.request.auth.AuthModule.IAP_HEADER_NAME;
import static google.registry.testing.DatabaseHelper.createAdminUser;
import static google.registry.testing.DatabaseHelper.persistResource;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.github.benmanes.caffeine.cache.LoadingCache;
import com.google.api.client.googleapis.auth.oauth2.GoogleIdToken.Payload;
import com.google.api.client.json.webtoken.JsonWebSignature;
import com.google.api.client.json.webtoken.JsonWebSignature.Header;
@@ -33,7 +40,9 @@ import dagger.Component;
import dagger.Module;
import dagger.Provides;
import google.registry.config.CredentialModule.ApplicationDefaultCredential;
import google.registry.config.RegistryConfig;
import google.registry.config.RegistryConfig.Config;
import google.registry.model.CacheUtils;
import google.registry.model.console.GlobalRole;
import google.registry.model.console.User;
import google.registry.model.console.UserRoles;
@@ -44,6 +53,7 @@ import google.registry.request.auth.OidcTokenAuthenticationMechanism.RegularOidc
import google.registry.util.GoogleCredentialsBundle;
import jakarta.inject.Singleton;
import jakarta.servlet.http.HttpServletRequest;
import java.util.Optional;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -54,6 +64,9 @@ public class OidcTokenAuthenticationMechanismTest {
private static final String rawToken = "this-token";
private static final String email = "user@email.test";
private static final String unknownEmail = "bad-guy@evil.real";
private static final String gaiaId = "gaia-id";
private static final ImmutableSet<String> serviceAccounts =
ImmutableSet.of("service@email.test", "email@service.goog");
@@ -75,6 +88,12 @@ public class OidcTokenAuthenticationMechanismTest {
@BeforeEach
void beforeEach() throws Exception {
// 1. Create a brand new cache.
LoadingCache<String, Optional<User>> testCache =
CacheUtils.newCacheBuilder(getUserAuthCachingDuration())
.maximumSize(getUserAuthMaxCachedEntries())
.build(OidcTokenAuthenticationMechanism::loadUser);
OidcTokenAuthenticationMechanism.setCacheForTesting(testCache);
payload.setEmail(email);
payload.setSubject(gaiaId);
user = createAdminUser(email);
@@ -154,7 +173,7 @@ public class OidcTokenAuthenticationMechanismTest {
@Test
void testAuthenticate_unknownEmailAddress() throws Exception {
payload.setEmail("bad-guy@evil.real");
payload.setEmail(unknownEmail);
authResult = authenticationMechanism.authenticate(request);
assertThat(authResult).isEqualTo(AuthResult.NOT_AUTHENTICATED);
}
@@ -189,6 +208,62 @@ public class OidcTokenAuthenticationMechanismTest {
authenticationMechanism = component.regularOidcAuthenticationMechanism();
}
@Test
void testAuthenticate_ExistentUser_isCached() {
// Arrange: Create a spy of the actual cache object.
// A spy calls the real methods of the object while allowing us to verify interactions.
LoadingCache<String, Optional<User>> spiedCache =
spy(OidcTokenAuthenticationMechanism.userCache);
OidcTokenAuthenticationMechanism.setCacheForTesting(spiedCache);
// Act: Call the authenticate method.
authenticationMechanism.authenticate(request);
// Assert: Verify that the cache's "get" method was called exactly once.
// This confirms the cache is being used without checking its internal stats.
verify(spiedCache).get(email);
}
@Test
void testAuthenticate_nonExistentUser_isCached() {
// Arrange: Use an email that is not in the test database.
payload.setEmail(unknownEmail);
LoadingCache<String, Optional<User>> spiedCache =
spy(OidcTokenAuthenticationMechanism.userCache);
OidcTokenAuthenticationMechanism.setCacheForTesting(spiedCache);
// Act: Call the authenticate method.
authenticationMechanism.authenticate(request);
// Assert: Verify that the cache's "get" method was called for the unverified email.
// This confirms that we attempted to look up the unknown user in the cache.
verify(spiedCache).get(unknownEmail);
}
@Test
void testAuthenticate_whenCacheIsDisabled_cacheIsNotUsed() {
// Arrange: Explicitly disable the cache and create a spy.
RegistryConfig.overrideIsUserAuthCachingEnabledForTesting(false);
LoadingCache<String, Optional<User>> spiedCache =
spy(OidcTokenAuthenticationMechanism.userCache);
OidcTokenAuthenticationMechanism.setCacheForTesting(spiedCache);
// Act: Authenticate the user.
AuthResult authResult = authenticationMechanism.authenticate(request);
// Assert: The authentication should still succeed because the code falls back
// to the direct database call.
assertThat(authResult.isAuthenticated()).isTrue();
// Assert: Crucially, verify that the cache's "get" method was NEVER called.
// This proves the cache was correctly bypassed.
verify(spiedCache, never()).get(any(String.class));
// Teardown: Restore the default setting for other tests.
RegistryConfig.overrideIsUserAuthCachingEnabledForTesting(true);
}
@Singleton
@Component(modules = {AuthModule.class, TestModule.class})
interface TestComponent {
@@ -234,4 +309,12 @@ public class OidcTokenAuthenticationMechanismTest {
return GoogleCredentialsBundle.create(GoogleCredentials.newBuilder().build());
}
}
private void reinitializeCache() {
OidcTokenAuthenticationMechanism.userCache =
CacheUtils.newCacheBuilder(getUserAuthCachingDuration())
.maximumSize(getUserAuthMaxCachedEntries())
.recordStats()
.build(OidcTokenAuthenticationMechanism::loadUser);
}
}