diff --git a/core/build.gradle b/core/build.gradle
index 8e0d28b23..b86ce81bf 100644
--- a/core/build.gradle
+++ b/core/build.gradle
@@ -58,6 +58,8 @@ def fragileTestPatterns = [
// Changes cache timeouts and for some reason appears to have contention
// with other tests.
"google/registry/whois/WhoisCommandFactoryTest.*",
+ // Breaks random other tests when running with standardTests.
+ "google/registry/bsa/UploadBsaUnavailableDomainsActionTest.*",
// Currently changes a global configuration parameter that for some reason
// results in timestamp inversions for other tests. TODO(mmuller): fix.
"google/registry/flows/host/HostInfoFlowTest.*",
diff --git a/core/src/main/java/google/registry/bsa/UploadBsaUnavailableDomainsAction.java b/core/src/main/java/google/registry/bsa/UploadBsaUnavailableDomainsAction.java
index c30ea788a..ea3bac08a 100644
--- a/core/src/main/java/google/registry/bsa/UploadBsaUnavailableDomainsAction.java
+++ b/core/src/main/java/google/registry/bsa/UploadBsaUnavailableDomainsAction.java
@@ -25,16 +25,16 @@ import static google.registry.request.Action.Method.GET;
import static google.registry.request.Action.Method.POST;
import static jakarta.servlet.http.HttpServletResponse.SC_INTERNAL_SERVER_ERROR;
import static java.nio.charset.StandardCharsets.US_ASCII;
+import static java.nio.charset.StandardCharsets.UTF_8;
import com.google.cloud.storage.BlobId;
-import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSortedSet;
import com.google.common.collect.Ordering;
import com.google.common.flogger.FluentLogger;
+import com.google.common.hash.Hasher;
import com.google.common.hash.Hashing;
-import com.google.common.io.ByteSource;
import google.registry.bsa.api.BsaCredential;
import google.registry.config.RegistryConfig.Config;
import google.registry.gcs.GcsUtils;
@@ -47,10 +47,13 @@ import google.registry.request.auth.Auth;
import google.registry.util.Clock;
import jakarta.inject.Inject;
import jakarta.persistence.TypedQuery;
-import java.io.ByteArrayOutputStream;
+import java.io.BufferedInputStream;
import java.io.IOException;
+import java.io.InputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
+import java.io.PipedInputStream;
+import java.io.PipedOutputStream;
import java.io.Writer;
import java.util.Optional;
import java.util.zip.GZIPOutputStream;
@@ -60,14 +63,17 @@ import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
+import okio.BufferedSink;
+import org.jetbrains.annotations.NotNull;
+import org.jetbrains.annotations.Nullable;
import org.joda.time.DateTime;
/**
* Daily action that uploads unavailable domain names on applicable TLDs to BSA.
*
*
The upload is a single zipped text file containing combined details for all BSA-enrolled TLDs.
- * The text is a newline-delimited list of punycoded fully qualified domain names, and contains all
- * domains on each TLD that are registered and/or reserved.
+ * The text is a newline-delimited list of punycoded fully qualified domain names with a trailing
+ * newline at the end, and contains all domains on each TLD that are registered and/or reserved.
*
*
The file is also uploaded to GCS to preserve it as a record for ourselves.
*/
@@ -118,7 +124,7 @@ public class UploadBsaUnavailableDomainsAction implements Runnable {
// TODO(mcilwain): Implement a date Cursor, have the cronjob run frequently, and short-circuit
// the run if the daily upload is already completed.
DateTime runTime = clock.nowUtc();
- String unavailableDomains = Joiner.on("\n").join(getUnavailableDomains(runTime));
+ ImmutableSortedSet unavailableDomains = getUnavailableDomains(runTime);
if (unavailableDomains.isEmpty()) {
logger.atWarning().log("No unavailable domains found; terminating.");
emailSender.sendNotification(
@@ -136,12 +142,16 @@ public class UploadBsaUnavailableDomainsAction implements Runnable {
}
/** Uploads the unavailable domains list to GCS in the unavailable domains bucket. */
- boolean uploadToGcs(String unavailableDomains, DateTime runTime) {
+ boolean uploadToGcs(ImmutableSortedSet unavailableDomains, DateTime runTime) {
logger.atInfo().log("Uploading unavailable names file to GCS in bucket %s", gcsBucket);
BlobId blobId = BlobId.of(gcsBucket, createFilename(runTime));
+ // `gcsUtils.openOutputStream` returns a buffered stream
try (OutputStream gcsOutput = gcsUtils.openOutputStream(blobId);
Writer osWriter = new OutputStreamWriter(gcsOutput, US_ASCII)) {
- osWriter.write(unavailableDomains);
+ for (var domainName : unavailableDomains) {
+ osWriter.write(domainName);
+ osWriter.write("\n");
+ }
return true;
} catch (Exception e) {
logger.atSevere().withCause(e).log(
@@ -150,10 +160,14 @@ public class UploadBsaUnavailableDomainsAction implements Runnable {
}
}
- boolean uploadToBsa(String unavailableDomains, DateTime runTime) {
+ boolean uploadToBsa(ImmutableSortedSet unavailableDomains, DateTime runTime) {
try {
- byte[] gzippedContents = gzipUnavailableDomains(unavailableDomains);
- String sha512Hash = ByteSource.wrap(gzippedContents).hash(Hashing.sha512()).toString();
+ Hasher sha512Hasher = Hashing.sha512().newHasher();
+ unavailableDomains.stream()
+ .map(name -> name + "\n")
+ .forEachOrdered(line -> sha512Hasher.putString(line, UTF_8));
+ String sha512Hash = sha512Hasher.hash().toString();
+
String filename = createFilename(runTime);
OkHttpClient client = new OkHttpClient().newBuilder().build();
@@ -169,7 +183,9 @@ public class UploadBsaUnavailableDomainsAction implements Runnable {
.addFormDataPart(
"file",
String.format("%s.gz", filename),
- RequestBody.create(gzippedContents, MediaType.parse("application/octet-stream")))
+ new StreamingRequestBody(
+ gzippedStream(unavailableDomains),
+ MediaType.parse("application/octet-stream")))
.build();
Request request =
@@ -196,15 +212,6 @@ public class UploadBsaUnavailableDomainsAction implements Runnable {
}
}
- private byte[] gzipUnavailableDomains(String unavailableDomains) throws IOException {
- try (ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream()) {
- try (GZIPOutputStream gzipOutputStream = new GZIPOutputStream(byteArrayOutputStream)) {
- gzipOutputStream.write(unavailableDomains.getBytes(US_ASCII));
- }
- return byteArrayOutputStream.toByteArray();
- }
- }
-
private static String createFilename(DateTime runTime) {
return String.format("unavailable_domains_%s.txt", runTime.toString());
}
@@ -280,4 +287,65 @@ public class UploadBsaUnavailableDomainsAction implements Runnable {
private static String toDomain(String domainLabel, Tld tld) {
return String.format("%s.%s", domainLabel, tld.getTldStr());
}
+
+ private InputStream gzippedStream(ImmutableSortedSet unavailableDomains)
+ throws IOException {
+ PipedInputStream inputStream = new PipedInputStream();
+ PipedOutputStream outputStream = new PipedOutputStream(inputStream);
+
+ new Thread(
+ () -> {
+ try {
+ gzipUnavailableDomains(outputStream, unavailableDomains);
+ } catch (Throwable e) {
+ logger.atSevere().withCause(e).log("Failed to gzip unavailable domains.");
+ try {
+ // This will cause the next read to throw an IOException.
+ inputStream.close();
+ } catch (IOException ignore) {
+ // Won't happen for `PipedInputStream.close()`
+ }
+ }
+ })
+ .start();
+
+ return inputStream;
+ }
+
+ private void gzipUnavailableDomains(
+ PipedOutputStream outputStream, ImmutableSortedSet unavailableDomains)
+ throws IOException {
+ // `GZIPOutputStream` is buffered.
+ try (GZIPOutputStream gzipOutputStream = new GZIPOutputStream(outputStream)) {
+ for (String name : unavailableDomains) {
+ var line = name + "\n";
+ gzipOutputStream.write(line.getBytes(US_ASCII));
+ }
+ }
+ }
+
+ private static class StreamingRequestBody extends RequestBody {
+ private final BufferedInputStream inputStream;
+ private final MediaType mediaType;
+
+ StreamingRequestBody(InputStream inputStream, MediaType mediaType) {
+ this.inputStream = new BufferedInputStream(inputStream);
+ this.mediaType = mediaType;
+ }
+
+ @Nullable
+ @Override
+ public MediaType contentType() {
+ return mediaType;
+ }
+
+ @Override
+ public void writeTo(@NotNull BufferedSink bufferedSink) throws IOException {
+ byte[] buffer = new byte[2048];
+ int bytesRead;
+ while ((bytesRead = inputStream.read(buffer)) != -1) {
+ bufferedSink.write(buffer, 0, bytesRead);
+ }
+ }
+ }
}
diff --git a/core/src/test/java/google/registry/bsa/UploadBsaUnavailableDomainsActionTest.java b/core/src/test/java/google/registry/bsa/UploadBsaUnavailableDomainsActionTest.java
index 7c078e5d4..26a1831bb 100644
--- a/core/src/test/java/google/registry/bsa/UploadBsaUnavailableDomainsActionTest.java
+++ b/core/src/test/java/google/registry/bsa/UploadBsaUnavailableDomainsActionTest.java
@@ -20,13 +20,24 @@ import static google.registry.testing.DatabaseHelper.persistActiveDomain;
import static google.registry.testing.DatabaseHelper.persistDeletedDomain;
import static google.registry.testing.DatabaseHelper.persistReservedList;
import static google.registry.testing.DatabaseHelper.persistResource;
+import static google.registry.testing.LogsSubject.assertAboutLogs;
import static google.registry.util.DateTimeUtils.START_OF_TIME;
+import static google.registry.util.NetworkUtils.pickUnusedPort;
import static java.nio.charset.StandardCharsets.UTF_8;
+import static java.util.concurrent.Executors.newSingleThreadExecutor;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import com.google.cloud.storage.BlobId;
import com.google.cloud.storage.contrib.nio.testing.LocalStorageHelper;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.flogger.FluentLogger;
+import com.google.common.hash.Hashing;
+import com.google.common.io.ByteStreams;
+import com.google.common.net.HostAndPort;
+import com.google.common.testing.TestLogHandler;
+import com.google.gson.Gson;
import google.registry.bsa.api.BsaCredential;
import google.registry.gcs.GcsUtils;
import google.registry.model.tld.Tld;
@@ -35,9 +46,25 @@ import google.registry.model.tld.label.ReservedList;
import google.registry.persistence.transaction.JpaTestExtensions;
import google.registry.persistence.transaction.JpaTestExtensions.JpaIntegrationTestExtension;
import google.registry.request.UrlConnectionService;
+import google.registry.server.Route;
+import google.registry.server.TestServer;
import google.registry.testing.FakeClock;
import google.registry.testing.FakeResponse;
+import jakarta.servlet.ServletException;
+import jakarta.servlet.annotation.MultipartConfig;
+import jakarta.servlet.http.HttpServlet;
+import jakarta.servlet.http.HttpServletRequest;
+import jakarta.servlet.http.HttpServletResponse;
+import jakarta.servlet.http.Part;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.PrintWriter;
+import java.net.InetAddress;
+import java.util.Map;
import java.util.Optional;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+import java.util.zip.GZIPInputStream;
import org.joda.time.DateTime;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -102,13 +129,112 @@ public class UploadBsaUnavailableDomainsActionTest {
BlobId existingFile =
BlobId.of(BUCKET, String.format("unavailable_domains_%s.txt", clock.nowUtc()));
String blockList = new String(gcsUtils.readBytesFrom(existingFile), UTF_8);
- assertThat(blockList).isEqualTo("ace.tld\nflagrant.tld\nfoobar.tld\njimmy.tld\ntine.tld");
+ assertThat(blockList).isEqualTo("ace.tld\nflagrant.tld\nfoobar.tld\njimmy.tld\ntine.tld\n");
assertThat(blockList).doesNotContain("not-blocked.tld");
// This test currently fails in the upload-to-bsa step.
verify(emailSender, times(1))
.sendNotification("BSA daily upload completed with errors", "Please see logs for details.");
+ }
- // TODO(mcilwain): Add test of BSA API upload as well.
+ @Test
+ void uploadToBsaTest() throws Exception {
+ TestLogHandler logHandler = new TestLogHandler();
+ Logger loggerToIntercept =
+ Logger.getLogger(UploadBsaUnavailableDomainsAction.class.getCanonicalName());
+ loggerToIntercept.addHandler(logHandler);
+
+ persistActiveDomain("foobar.tld");
+ persistActiveDomain("ace.tld");
+ persistDeletedDomain("not-blocked.tld", clock.nowUtc().minusDays(1));
+
+ var testServer = startTestServer();
+ action.apiUrl = testServer.getUrl("/upload").toURI().toString();
+ try {
+ action.run();
+ } finally {
+ testServer.stop();
+ }
+ String dataSent = "ace.tld\nflagrant.tld\nfoobar.tld\njimmy.tld\ntine.tld\n";
+ String checkSum = Hashing.sha512().hashString(dataSent, UTF_8).toString();
+ String expectedResponse =
+ "Received response with code 200 from server: "
+ + String.format("Checksum: [%s]\n%s\n", checkSum, dataSent);
+ assertAboutLogs().that(logHandler).hasLogAtLevelWithMessage(Level.INFO, expectedResponse);
+ verify(emailSender, times(1)).sendNotification("BSA daily upload completed successfully", "");
+ }
+
+ private TestServer startTestServer() throws Exception {
+ TestServer testServer =
+ new TestServer(
+ HostAndPort.fromParts(InetAddress.getLocalHost().getHostAddress(), pickUnusedPort()),
+ ImmutableMap.of(),
+ ImmutableList.of(Route.route("/upload", Servelet.class)));
+ testServer.start();
+ newSingleThreadExecutor()
+ .execute(
+ () -> {
+ try {
+ while (true) {
+ testServer.process();
+ }
+ } catch (InterruptedException e) {
+ // Expected
+ }
+ });
+ return testServer;
+ }
+
+ @MultipartConfig(
+ location = "", // Directory for storing uploaded files. Use default when blank
+ maxFileSize = 10485760L, // 10MB
+ maxRequestSize = 20971520L, // 20MB
+ fileSizeThreshold = 1048576 // Save in memory if file size < 1MB
+ )
+ public static class Servelet extends HttpServlet {
+ private static final FluentLogger logger = FluentLogger.forEnclosingClass();
+
+ @Override
+ protected void doPost(HttpServletRequest req, HttpServletResponse resp)
+ throws ServletException, IOException {
+ String checkSum = null;
+ String content = null;
+ try {
+ for (Part part : req.getParts()) {
+ switch (part.getName()) {
+ case "zone" -> checkSum = readChecksum(part);
+ case "file" -> content = readGzipped(part);
+ }
+ }
+ } catch (Exception e) {
+ logger.atInfo().withCause(e).log("");
+ }
+ int status = checkSum == null || content == null ? 400 : 200;
+ resp.setStatus(status);
+ resp.setContentType("text/plain");
+ try (PrintWriter writer = resp.getWriter()) {
+ writer.printf("Checksum: [%s]\n%s\n", checkSum, content);
+ }
+ }
+
+ private String readChecksum(Part part) {
+ try (InputStream is = part.getInputStream()) {
+ return new Gson()
+ .fromJson(new String(ByteStreams.toByteArray(is), UTF_8), Map.class)
+ .getOrDefault("checkSum", "Not found")
+ .toString();
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private String readGzipped(Part part) {
+ try (InputStream is = part.getInputStream();
+ GZIPInputStream gis = new GZIPInputStream(is)) {
+ return new String(ByteStreams.toByteArray(gis), UTF_8);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
}
}
diff --git a/core/src/test/java/google/registry/server/TestServer.java b/core/src/test/java/google/registry/server/TestServer.java
index 2e457b4da..c37a15358 100644
--- a/core/src/test/java/google/registry/server/TestServer.java
+++ b/core/src/test/java/google/registry/server/TestServer.java
@@ -26,10 +26,16 @@ import com.google.common.net.HostAndPort;
import com.google.common.util.concurrent.SimpleTimeLimiter;
import google.registry.util.RegistryEnvironment;
import google.registry.util.UrlChecker;
+import jakarta.servlet.MultipartConfigElement;
+import jakarta.servlet.annotation.MultipartConfig;
import jakarta.servlet.http.HttpServlet;
+import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
+import java.nio.file.Files;
import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.FutureTask;
@@ -49,6 +55,8 @@ import org.eclipse.jetty.server.ServerConnector;
* {@link #stop()} methods. However, a {@link #process()} method was added, which is used to process
* requests made to servlets (not static files) in the calling thread.
*
+ * A servlet that expects multi-part requests should be annotated with {@link MultipartConfig}.
+ *
*
Note: This server is intended for development purposes. For the love all that is good,
* do not make this public-facing.
*
@@ -70,6 +78,7 @@ public final class TestServer {
private final HostAndPort urlAddress;
private final Server server = new Server();
private final BlockingQueue> requestQueue = new LinkedBlockingDeque<>();
+ private List multiPartTmpDirs = new ArrayList<>();
/**
* Creates a new instance, but does not begin serving.
@@ -134,6 +143,13 @@ public final class TestServer {
},
SHUTDOWN_TIMEOUT_MS,
TimeUnit.MILLISECONDS);
+ for (var dir : multiPartTmpDirs) {
+ try {
+ Files.delete(dir);
+ } catch (Exception e) {
+ // Ignore
+ }
+ }
} catch (Exception e) {
throwIfUnchecked(e);
throw new RuntimeException(e);
@@ -161,7 +177,27 @@ public final class TestServer {
StaticResourceServlet.configureServletHolder(holder, runfile.getKey(), runfile.getValue());
}
for (Route route : routes) {
- context.addServlet(wrapServlet(route.servletClass()), route.path());
+ holder = context.addServlet(wrapServlet(route.servletClass()), route.path());
+ MultipartConfig multipartConfig = route.servletClass().getAnnotation(MultipartConfig.class);
+ if (multipartConfig != null) {
+ try {
+ var location = multipartConfig.location();
+ if (location == null || location.isBlank()) {
+ Path tmpDir = Files.createTempDirectory("TestServer_");
+ multiPartTmpDirs.add(tmpDir);
+ location = tmpDir.toString();
+ }
+ MultipartConfigElement multipartConfigElement =
+ new MultipartConfigElement(
+ location,
+ multipartConfig.maxFileSize(),
+ multipartConfig.maxRequestSize(),
+ multipartConfig.fileSizeThreshold());
+ holder.getRegistration().setMultipartConfig(multipartConfigElement);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
}
holder = context.addServlet(DefaultServlet.class, "/*");
holder.setInitParameter("aliases", "1");