diff --git a/core/src/main/java/google/registry/tools/server/RefreshDnsForAllDomainsAction.java b/core/src/main/java/google/registry/tools/server/RefreshDnsForAllDomainsAction.java index d4720028d..de89ffe95 100644 --- a/core/src/main/java/google/registry/tools/server/RefreshDnsForAllDomainsAction.java +++ b/core/src/main/java/google/registry/tools/server/RefreshDnsForAllDomainsAction.java @@ -15,20 +15,25 @@ package google.registry.tools.server; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Iterables.getLast; import static google.registry.dns.DnsUtils.requestDomainDnsRefresh; import static google.registry.model.tld.Tlds.assertTldsExist; import static google.registry.persistence.transaction.TransactionManagerFactory.tm; import static google.registry.request.RequestParameters.PARAM_TLDS; +import static google.registry.util.DateTimeUtils.END_OF_TIME; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.flogger.FluentLogger; import google.registry.request.Action; import google.registry.request.Parameter; import google.registry.request.Response; import google.registry.request.auth.Auth; -import google.registry.util.Clock; +import java.util.Optional; import java.util.Random; import javax.inject.Inject; +import org.apache.arrow.util.VisibleForTesting; import org.apache.http.HttpStatus; import org.joda.time.Duration; @@ -43,10 +48,8 @@ import org.joda.time.Duration; * run internally, or by pretending to be internal by setting the X-AppEngine-QueueName header, * which only admin users can do. * - *

You must pass in a number of {@code smearMinutes} as a URL parameter so that the DNS queue - * doesn't get overloaded. A rough rule of thumb for Cloud DNS is 1 minute per every 1,000 domains. - * This smears the updates out over the next N minutes. For small TLDs consisting of fewer than - * 1,000 domains, passing in 1 is fine (which will execute all the updates immediately). + *

You may pass in a {@code batchSize} for the batched read of domains from the database. This is + * recommended to be somewhere between 200 and 500. The default value is 250. */ @Action( service = Action.Service.TOOLS, @@ -56,47 +59,78 @@ public class RefreshDnsForAllDomainsAction implements Runnable { private static final FluentLogger logger = FluentLogger.forEnclosingClass(); - @Inject Response response; + private static final int DEFAULT_BATCH_SIZE = 250; + + private final Response response; + private final ImmutableSet tlds; + + // Recommended value for batch size is between 200 and 500 + private final int batchSize; + private final Random random; @Inject - @Parameter(PARAM_TLDS) - ImmutableSet tlds; - - @Inject - @Parameter("smearMinutes") - int smearMinutes; - - @Inject Clock clock; - @Inject Random random; - - @Inject - RefreshDnsForAllDomainsAction() {} + RefreshDnsForAllDomainsAction( + Response response, + @Parameter(PARAM_TLDS) ImmutableSet tlds, + @Parameter("batchSize") Optional batchSize, + Random random) { + this.response = response; + this.tlds = tlds; + this.batchSize = batchSize.orElse(DEFAULT_BATCH_SIZE); + this.random = random; + } @Override public void run() { assertTldsExist(tlds); - checkArgument(smearMinutes > 0, "Must specify a positive number of smear minutes"); - tm().transact( - () -> - tm().query( - "SELECT domainName FROM Domain " - + "WHERE tld IN (:tlds) " - + "AND deletionTime > :now", - String.class) - .setParameter("tlds", tlds) - .setParameter("now", clock.nowUtc()) - .getResultStream() - .forEach( - domainName -> { - try { - // Smear the task execution time over the next N minutes. - requestDomainDnsRefresh( - domainName, Duration.standardMinutes(random.nextInt(smearMinutes))); - } catch (Throwable t) { - logger.atSevere().withCause(t).log( - "Error while enqueuing DNS refresh for domain '%s'.", domainName); - response.setStatus(HttpStatus.SC_INTERNAL_SERVER_ERROR); - } - })); + checkArgument(batchSize > 0, "Must specify a positive number for batch size"); + int smearMinutes = tm().transact(this::calculateSmearMinutes); + ImmutableList previousBatch = ImmutableList.of(""); + do { + String lastInPreviousBatch = getLast(previousBatch); + previousBatch = tm().transact(() -> refreshBatch(lastInPreviousBatch, smearMinutes)); + } while (previousBatch.size() == batchSize); + } + + /** + * Calculates the number of smear minutes to enqueue refreshes so that the DNS queue does not get + * overloaded. + */ + private int calculateSmearMinutes() { + Long activeDomains = + tm().query( + "SELECT COUNT(*) FROM Domain WHERE tld IN (:tlds) AND deletionTime = :endOfTime", + Long.class) + .setParameter("tlds", tlds) + .setParameter("endOfTime", END_OF_TIME) + .getSingleResult(); + return Math.max(activeDomains.intValue() / 1000, 1); + } + + private ImmutableList getBatch(String lastInPreviousBatch) { + return tm().query( + "SELECT domainName FROM Domain WHERE tld IN (:tlds) AND" + + " deletionTime = :endOfTime AND domainName >" + + " :lastInPreviousBatch ORDER BY domainName ASC", + String.class) + .setParameter("tlds", tlds) + .setParameter("endOfTime", END_OF_TIME) + .setParameter("lastInPreviousBatch", lastInPreviousBatch) + .setMaxResults(batchSize) + .getResultStream() + .collect(toImmutableList()); + } + + @VisibleForTesting + ImmutableList refreshBatch(String lastInPreviousBatch, int smearMinutes) { + ImmutableList domainBatch = getBatch(lastInPreviousBatch); + try { + // Smear the task execution time over the next N minutes. + requestDomainDnsRefresh(domainBatch, Duration.standardMinutes(random.nextInt(smearMinutes))); + } catch (Throwable t) { + logger.atSevere().withCause(t).log("Error while enqueuing DNS refresh batch"); + response.setStatus(HttpStatus.SC_OK); + } + return domainBatch; } } diff --git a/core/src/main/java/google/registry/tools/server/ToolsServerModule.java b/core/src/main/java/google/registry/tools/server/ToolsServerModule.java index dc617679e..a85b3c10f 100644 --- a/core/src/main/java/google/registry/tools/server/ToolsServerModule.java +++ b/core/src/main/java/google/registry/tools/server/ToolsServerModule.java @@ -16,6 +16,7 @@ package google.registry.tools.server; import static com.google.common.base.Strings.emptyToNull; import static google.registry.request.RequestParameters.extractIntParameter; +import static google.registry.request.RequestParameters.extractOptionalIntParameter; import static google.registry.request.RequestParameters.extractOptionalParameter; import static google.registry.request.RequestParameters.extractRequiredParameter; @@ -76,8 +77,8 @@ public class ToolsServerModule { } @Provides - @Parameter("smearMinutes") - static int provideSmearMinutes(HttpServletRequest req) { - return extractIntParameter(req, "smearMinutes"); + @Parameter("batchSize") + static Optional provideBatchSize(HttpServletRequest req) { + return extractOptionalIntParameter(req, "batchSize"); } } diff --git a/core/src/test/java/google/registry/testing/DatabaseHelper.java b/core/src/test/java/google/registry/testing/DatabaseHelper.java index 07cb0a5fa..c0a5a7947 100644 --- a/core/src/test/java/google/registry/testing/DatabaseHelper.java +++ b/core/src/test/java/google/registry/testing/DatabaseHelper.java @@ -1357,5 +1357,16 @@ public final class DatabaseHelper { .isEqualTo(1); } + public static void assertDnsRequestsWithRequestTime(DateTime requestTime, int numOfDomains) { + assertThat( + tm().transact( + () -> + tm().createQueryComposer(DnsRefreshRequest.class) + .where("type", EQ, DnsUtils.TargetType.DOMAIN) + .where("requestTime", EQ, requestTime) + .count())) + .isEqualTo(numOfDomains); + } + private DatabaseHelper() {} } diff --git a/core/src/test/java/google/registry/tools/server/RefreshDnsForAllDomainsActionTest.java b/core/src/test/java/google/registry/tools/server/RefreshDnsForAllDomainsActionTest.java index 0892169fb..9d6c4b15c 100644 --- a/core/src/test/java/google/registry/tools/server/RefreshDnsForAllDomainsActionTest.java +++ b/core/src/test/java/google/registry/tools/server/RefreshDnsForAllDomainsActionTest.java @@ -15,18 +15,24 @@ package google.registry.tools.server; import static com.google.common.truth.Truth.assertThat; +import static google.registry.persistence.transaction.QueryComposer.Comparator.EQ; +import static google.registry.persistence.transaction.TransactionManagerFactory.tm; +import static google.registry.testing.DatabaseHelper.assertDnsRequestsWithRequestTime; import static google.registry.testing.DatabaseHelper.assertDomainDnsRequestWithRequestTime; import static google.registry.testing.DatabaseHelper.assertNoDnsRequestsExcept; import static google.registry.testing.DatabaseHelper.createTld; import static google.registry.testing.DatabaseHelper.persistActiveDomain; import static google.registry.testing.DatabaseHelper.persistDeletedDomain; -import static org.junit.jupiter.api.Assertions.assertThrows; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import google.registry.dns.DnsUtils; +import google.registry.model.common.DnsRefreshRequest; import google.registry.persistence.transaction.JpaTestExtensions; import google.registry.persistence.transaction.JpaTestExtensions.JpaIntegrationTestExtension; import google.registry.testing.FakeClock; import google.registry.testing.FakeResponse; +import java.util.Optional; import java.util.Random; import org.joda.time.DateTime; import org.junit.jupiter.api.BeforeEach; @@ -46,20 +52,16 @@ public class RefreshDnsForAllDomainsActionTest { @BeforeEach void beforeEach() { - action = new RefreshDnsForAllDomainsAction(); - action.smearMinutes = 1; - action.random = new Random(); - action.random.setSeed(123L); - action.clock = clock; - action.response = response; createTld("bar"); + action = + new RefreshDnsForAllDomainsAction( + response, ImmutableSet.of("bar"), Optional.of(10), new Random()); } @Test void test_runAction_successfullyEnqueuesDnsRefreshes() throws Exception { persistActiveDomain("foo.bar"); persistActiveDomain("low.bar"); - action.tlds = ImmutableSet.of("bar"); action.run(); assertDomainDnsRequestWithRequestTime("foo.bar", clock.nowUtc()); assertDomainDnsRequestWithRequestTime("low.bar", clock.nowUtc()); @@ -69,18 +71,27 @@ public class RefreshDnsForAllDomainsActionTest { void test_runAction_smearsOutDnsRefreshes() throws Exception { persistActiveDomain("foo.bar"); persistActiveDomain("low.bar"); - action.tlds = ImmutableSet.of("bar"); - action.smearMinutes = 1000; - action.run(); - assertDomainDnsRequestWithRequestTime("foo.bar", clock.nowUtc().plusMinutes(450)); - assertDomainDnsRequestWithRequestTime("low.bar", clock.nowUtc().plusMinutes(782)); + // Set batch size to 1 since each batch will be enqueud at the same time + action = + new RefreshDnsForAllDomainsAction( + response, ImmutableSet.of("bar"), Optional.of(1), new Random()); + tm().transact(() -> action.refreshBatch("", 1000)); + tm().transact(() -> action.refreshBatch("", 1000)); + ImmutableList refreshRequests = + tm().transact( + () -> + tm().createQueryComposer(DnsRefreshRequest.class) + .where("type", EQ, DnsUtils.TargetType.DOMAIN) + .list()); + assertThat(refreshRequests.size()).isEqualTo(2); + assertThat(refreshRequests.get(0).getRequestTime()) + .isNotEqualTo(refreshRequests.get(1).getRequestTime()); } @Test void test_runAction_doesntRefreshDeletedDomain() throws Exception { persistActiveDomain("foo.bar"); persistDeletedDomain("deleted.bar", clock.nowUtc().minusYears(1)); - action.tlds = ImmutableSet.of("bar"); action.run(); assertDomainDnsRequestWithRequestTime("foo.bar", clock.nowUtc()); assertNoDnsRequestsExcept("foo.bar"); @@ -92,7 +103,6 @@ public class RefreshDnsForAllDomainsActionTest { persistActiveDomain("foo.bar"); persistActiveDomain("low.bar"); persistActiveDomain("ignore.baz"); - action.tlds = ImmutableSet.of("bar"); action.run(); assertDomainDnsRequestWithRequestTime("foo.bar", clock.nowUtc()); assertDomainDnsRequestWithRequestTime("low.bar", clock.nowUtc()); @@ -100,13 +110,11 @@ public class RefreshDnsForAllDomainsActionTest { } @Test - void test_smearMinutesMustBeSpecified() { - action.tlds = ImmutableSet.of("bar"); - action.smearMinutes = 0; - IllegalArgumentException thrown = - assertThrows(IllegalArgumentException.class, () -> action.run()); - assertThat(thrown) - .hasMessageThat() - .isEqualTo("Must specify a positive number of smear minutes"); + void test_successfullyBatchesNames() { + for (int i = 0; i <= 10; i++) { + persistActiveDomain(String.format("test%s.bar", i)); + } + action.run(); + assertDnsRequestsWithRequestTime(clock.nowUtc(), 11); } }