diff --git a/core/src/main/java/google/registry/tools/server/RefreshDnsForAllDomainsAction.java b/core/src/main/java/google/registry/tools/server/RefreshDnsForAllDomainsAction.java
index d4720028d..de89ffe95 100644
--- a/core/src/main/java/google/registry/tools/server/RefreshDnsForAllDomainsAction.java
+++ b/core/src/main/java/google/registry/tools/server/RefreshDnsForAllDomainsAction.java
@@ -15,20 +15,25 @@
package google.registry.tools.server;
import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static com.google.common.collect.Iterables.getLast;
import static google.registry.dns.DnsUtils.requestDomainDnsRefresh;
import static google.registry.model.tld.Tlds.assertTldsExist;
import static google.registry.persistence.transaction.TransactionManagerFactory.tm;
import static google.registry.request.RequestParameters.PARAM_TLDS;
+import static google.registry.util.DateTimeUtils.END_OF_TIME;
+import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.flogger.FluentLogger;
import google.registry.request.Action;
import google.registry.request.Parameter;
import google.registry.request.Response;
import google.registry.request.auth.Auth;
-import google.registry.util.Clock;
+import java.util.Optional;
import java.util.Random;
import javax.inject.Inject;
+import org.apache.arrow.util.VisibleForTesting;
import org.apache.http.HttpStatus;
import org.joda.time.Duration;
@@ -43,10 +48,8 @@ import org.joda.time.Duration;
* run internally, or by pretending to be internal by setting the X-AppEngine-QueueName header,
* which only admin users can do.
*
- *
You must pass in a number of {@code smearMinutes} as a URL parameter so that the DNS queue
- * doesn't get overloaded. A rough rule of thumb for Cloud DNS is 1 minute per every 1,000 domains.
- * This smears the updates out over the next N minutes. For small TLDs consisting of fewer than
- * 1,000 domains, passing in 1 is fine (which will execute all the updates immediately).
+ *
You may pass in a {@code batchSize} for the batched read of domains from the database. This is
+ * recommended to be somewhere between 200 and 500. The default value is 250.
*/
@Action(
service = Action.Service.TOOLS,
@@ -56,47 +59,78 @@ public class RefreshDnsForAllDomainsAction implements Runnable {
private static final FluentLogger logger = FluentLogger.forEnclosingClass();
- @Inject Response response;
+ private static final int DEFAULT_BATCH_SIZE = 250;
+
+ private final Response response;
+ private final ImmutableSet tlds;
+
+ // Recommended value for batch size is between 200 and 500
+ private final int batchSize;
+ private final Random random;
@Inject
- @Parameter(PARAM_TLDS)
- ImmutableSet tlds;
-
- @Inject
- @Parameter("smearMinutes")
- int smearMinutes;
-
- @Inject Clock clock;
- @Inject Random random;
-
- @Inject
- RefreshDnsForAllDomainsAction() {}
+ RefreshDnsForAllDomainsAction(
+ Response response,
+ @Parameter(PARAM_TLDS) ImmutableSet tlds,
+ @Parameter("batchSize") Optional batchSize,
+ Random random) {
+ this.response = response;
+ this.tlds = tlds;
+ this.batchSize = batchSize.orElse(DEFAULT_BATCH_SIZE);
+ this.random = random;
+ }
@Override
public void run() {
assertTldsExist(tlds);
- checkArgument(smearMinutes > 0, "Must specify a positive number of smear minutes");
- tm().transact(
- () ->
- tm().query(
- "SELECT domainName FROM Domain "
- + "WHERE tld IN (:tlds) "
- + "AND deletionTime > :now",
- String.class)
- .setParameter("tlds", tlds)
- .setParameter("now", clock.nowUtc())
- .getResultStream()
- .forEach(
- domainName -> {
- try {
- // Smear the task execution time over the next N minutes.
- requestDomainDnsRefresh(
- domainName, Duration.standardMinutes(random.nextInt(smearMinutes)));
- } catch (Throwable t) {
- logger.atSevere().withCause(t).log(
- "Error while enqueuing DNS refresh for domain '%s'.", domainName);
- response.setStatus(HttpStatus.SC_INTERNAL_SERVER_ERROR);
- }
- }));
+ checkArgument(batchSize > 0, "Must specify a positive number for batch size");
+ int smearMinutes = tm().transact(this::calculateSmearMinutes);
+ ImmutableList previousBatch = ImmutableList.of("");
+ do {
+ String lastInPreviousBatch = getLast(previousBatch);
+ previousBatch = tm().transact(() -> refreshBatch(lastInPreviousBatch, smearMinutes));
+ } while (previousBatch.size() == batchSize);
+ }
+
+ /**
+ * Calculates the number of smear minutes to enqueue refreshes so that the DNS queue does not get
+ * overloaded.
+ */
+ private int calculateSmearMinutes() {
+ Long activeDomains =
+ tm().query(
+ "SELECT COUNT(*) FROM Domain WHERE tld IN (:tlds) AND deletionTime = :endOfTime",
+ Long.class)
+ .setParameter("tlds", tlds)
+ .setParameter("endOfTime", END_OF_TIME)
+ .getSingleResult();
+ return Math.max(activeDomains.intValue() / 1000, 1);
+ }
+
+ private ImmutableList getBatch(String lastInPreviousBatch) {
+ return tm().query(
+ "SELECT domainName FROM Domain WHERE tld IN (:tlds) AND"
+ + " deletionTime = :endOfTime AND domainName >"
+ + " :lastInPreviousBatch ORDER BY domainName ASC",
+ String.class)
+ .setParameter("tlds", tlds)
+ .setParameter("endOfTime", END_OF_TIME)
+ .setParameter("lastInPreviousBatch", lastInPreviousBatch)
+ .setMaxResults(batchSize)
+ .getResultStream()
+ .collect(toImmutableList());
+ }
+
+ @VisibleForTesting
+ ImmutableList refreshBatch(String lastInPreviousBatch, int smearMinutes) {
+ ImmutableList domainBatch = getBatch(lastInPreviousBatch);
+ try {
+ // Smear the task execution time over the next N minutes.
+ requestDomainDnsRefresh(domainBatch, Duration.standardMinutes(random.nextInt(smearMinutes)));
+ } catch (Throwable t) {
+ logger.atSevere().withCause(t).log("Error while enqueuing DNS refresh batch");
+ response.setStatus(HttpStatus.SC_OK);
+ }
+ return domainBatch;
}
}
diff --git a/core/src/main/java/google/registry/tools/server/ToolsServerModule.java b/core/src/main/java/google/registry/tools/server/ToolsServerModule.java
index dc617679e..a85b3c10f 100644
--- a/core/src/main/java/google/registry/tools/server/ToolsServerModule.java
+++ b/core/src/main/java/google/registry/tools/server/ToolsServerModule.java
@@ -16,6 +16,7 @@ package google.registry.tools.server;
import static com.google.common.base.Strings.emptyToNull;
import static google.registry.request.RequestParameters.extractIntParameter;
+import static google.registry.request.RequestParameters.extractOptionalIntParameter;
import static google.registry.request.RequestParameters.extractOptionalParameter;
import static google.registry.request.RequestParameters.extractRequiredParameter;
@@ -76,8 +77,8 @@ public class ToolsServerModule {
}
@Provides
- @Parameter("smearMinutes")
- static int provideSmearMinutes(HttpServletRequest req) {
- return extractIntParameter(req, "smearMinutes");
+ @Parameter("batchSize")
+ static Optional provideBatchSize(HttpServletRequest req) {
+ return extractOptionalIntParameter(req, "batchSize");
}
}
diff --git a/core/src/test/java/google/registry/testing/DatabaseHelper.java b/core/src/test/java/google/registry/testing/DatabaseHelper.java
index 07cb0a5fa..c0a5a7947 100644
--- a/core/src/test/java/google/registry/testing/DatabaseHelper.java
+++ b/core/src/test/java/google/registry/testing/DatabaseHelper.java
@@ -1357,5 +1357,16 @@ public final class DatabaseHelper {
.isEqualTo(1);
}
+ public static void assertDnsRequestsWithRequestTime(DateTime requestTime, int numOfDomains) {
+ assertThat(
+ tm().transact(
+ () ->
+ tm().createQueryComposer(DnsRefreshRequest.class)
+ .where("type", EQ, DnsUtils.TargetType.DOMAIN)
+ .where("requestTime", EQ, requestTime)
+ .count()))
+ .isEqualTo(numOfDomains);
+ }
+
private DatabaseHelper() {}
}
diff --git a/core/src/test/java/google/registry/tools/server/RefreshDnsForAllDomainsActionTest.java b/core/src/test/java/google/registry/tools/server/RefreshDnsForAllDomainsActionTest.java
index 0892169fb..9d6c4b15c 100644
--- a/core/src/test/java/google/registry/tools/server/RefreshDnsForAllDomainsActionTest.java
+++ b/core/src/test/java/google/registry/tools/server/RefreshDnsForAllDomainsActionTest.java
@@ -15,18 +15,24 @@
package google.registry.tools.server;
import static com.google.common.truth.Truth.assertThat;
+import static google.registry.persistence.transaction.QueryComposer.Comparator.EQ;
+import static google.registry.persistence.transaction.TransactionManagerFactory.tm;
+import static google.registry.testing.DatabaseHelper.assertDnsRequestsWithRequestTime;
import static google.registry.testing.DatabaseHelper.assertDomainDnsRequestWithRequestTime;
import static google.registry.testing.DatabaseHelper.assertNoDnsRequestsExcept;
import static google.registry.testing.DatabaseHelper.createTld;
import static google.registry.testing.DatabaseHelper.persistActiveDomain;
import static google.registry.testing.DatabaseHelper.persistDeletedDomain;
-import static org.junit.jupiter.api.Assertions.assertThrows;
+import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
+import google.registry.dns.DnsUtils;
+import google.registry.model.common.DnsRefreshRequest;
import google.registry.persistence.transaction.JpaTestExtensions;
import google.registry.persistence.transaction.JpaTestExtensions.JpaIntegrationTestExtension;
import google.registry.testing.FakeClock;
import google.registry.testing.FakeResponse;
+import java.util.Optional;
import java.util.Random;
import org.joda.time.DateTime;
import org.junit.jupiter.api.BeforeEach;
@@ -46,20 +52,16 @@ public class RefreshDnsForAllDomainsActionTest {
@BeforeEach
void beforeEach() {
- action = new RefreshDnsForAllDomainsAction();
- action.smearMinutes = 1;
- action.random = new Random();
- action.random.setSeed(123L);
- action.clock = clock;
- action.response = response;
createTld("bar");
+ action =
+ new RefreshDnsForAllDomainsAction(
+ response, ImmutableSet.of("bar"), Optional.of(10), new Random());
}
@Test
void test_runAction_successfullyEnqueuesDnsRefreshes() throws Exception {
persistActiveDomain("foo.bar");
persistActiveDomain("low.bar");
- action.tlds = ImmutableSet.of("bar");
action.run();
assertDomainDnsRequestWithRequestTime("foo.bar", clock.nowUtc());
assertDomainDnsRequestWithRequestTime("low.bar", clock.nowUtc());
@@ -69,18 +71,27 @@ public class RefreshDnsForAllDomainsActionTest {
void test_runAction_smearsOutDnsRefreshes() throws Exception {
persistActiveDomain("foo.bar");
persistActiveDomain("low.bar");
- action.tlds = ImmutableSet.of("bar");
- action.smearMinutes = 1000;
- action.run();
- assertDomainDnsRequestWithRequestTime("foo.bar", clock.nowUtc().plusMinutes(450));
- assertDomainDnsRequestWithRequestTime("low.bar", clock.nowUtc().plusMinutes(782));
+ // Set batch size to 1 since each batch will be enqueud at the same time
+ action =
+ new RefreshDnsForAllDomainsAction(
+ response, ImmutableSet.of("bar"), Optional.of(1), new Random());
+ tm().transact(() -> action.refreshBatch("", 1000));
+ tm().transact(() -> action.refreshBatch("", 1000));
+ ImmutableList refreshRequests =
+ tm().transact(
+ () ->
+ tm().createQueryComposer(DnsRefreshRequest.class)
+ .where("type", EQ, DnsUtils.TargetType.DOMAIN)
+ .list());
+ assertThat(refreshRequests.size()).isEqualTo(2);
+ assertThat(refreshRequests.get(0).getRequestTime())
+ .isNotEqualTo(refreshRequests.get(1).getRequestTime());
}
@Test
void test_runAction_doesntRefreshDeletedDomain() throws Exception {
persistActiveDomain("foo.bar");
persistDeletedDomain("deleted.bar", clock.nowUtc().minusYears(1));
- action.tlds = ImmutableSet.of("bar");
action.run();
assertDomainDnsRequestWithRequestTime("foo.bar", clock.nowUtc());
assertNoDnsRequestsExcept("foo.bar");
@@ -92,7 +103,6 @@ public class RefreshDnsForAllDomainsActionTest {
persistActiveDomain("foo.bar");
persistActiveDomain("low.bar");
persistActiveDomain("ignore.baz");
- action.tlds = ImmutableSet.of("bar");
action.run();
assertDomainDnsRequestWithRequestTime("foo.bar", clock.nowUtc());
assertDomainDnsRequestWithRequestTime("low.bar", clock.nowUtc());
@@ -100,13 +110,11 @@ public class RefreshDnsForAllDomainsActionTest {
}
@Test
- void test_smearMinutesMustBeSpecified() {
- action.tlds = ImmutableSet.of("bar");
- action.smearMinutes = 0;
- IllegalArgumentException thrown =
- assertThrows(IllegalArgumentException.class, () -> action.run());
- assertThat(thrown)
- .hasMessageThat()
- .isEqualTo("Must specify a positive number of smear minutes");
+ void test_successfullyBatchesNames() {
+ for (int i = 0; i <= 10; i++) {
+ persistActiveDomain(String.format("test%s.bar", i));
+ }
+ action.run();
+ assertDnsRequestsWithRequestTime(clock.nowUtc(), 11);
}
}