diff --git a/core/build.gradle b/core/build.gradle index 8e0d28b23..b86ce81bf 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -58,6 +58,8 @@ def fragileTestPatterns = [ // Changes cache timeouts and for some reason appears to have contention // with other tests. "google/registry/whois/WhoisCommandFactoryTest.*", + // Breaks random other tests when running with standardTests. + "google/registry/bsa/UploadBsaUnavailableDomainsActionTest.*", // Currently changes a global configuration parameter that for some reason // results in timestamp inversions for other tests. TODO(mmuller): fix. "google/registry/flows/host/HostInfoFlowTest.*", diff --git a/core/src/main/java/google/registry/bsa/UploadBsaUnavailableDomainsAction.java b/core/src/main/java/google/registry/bsa/UploadBsaUnavailableDomainsAction.java index c30ea788a..ea3bac08a 100644 --- a/core/src/main/java/google/registry/bsa/UploadBsaUnavailableDomainsAction.java +++ b/core/src/main/java/google/registry/bsa/UploadBsaUnavailableDomainsAction.java @@ -25,16 +25,16 @@ import static google.registry.request.Action.Method.GET; import static google.registry.request.Action.Method.POST; import static jakarta.servlet.http.HttpServletResponse.SC_INTERNAL_SERVER_ERROR; import static java.nio.charset.StandardCharsets.US_ASCII; +import static java.nio.charset.StandardCharsets.UTF_8; import com.google.cloud.storage.BlobId; -import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSortedSet; import com.google.common.collect.Ordering; import com.google.common.flogger.FluentLogger; +import com.google.common.hash.Hasher; import com.google.common.hash.Hashing; -import com.google.common.io.ByteSource; import google.registry.bsa.api.BsaCredential; import google.registry.config.RegistryConfig.Config; import google.registry.gcs.GcsUtils; @@ -47,10 +47,13 @@ import google.registry.request.auth.Auth; import google.registry.util.Clock; import jakarta.inject.Inject; import jakarta.persistence.TypedQuery; -import java.io.ByteArrayOutputStream; +import java.io.BufferedInputStream; import java.io.IOException; +import java.io.InputStream; import java.io.OutputStream; import java.io.OutputStreamWriter; +import java.io.PipedInputStream; +import java.io.PipedOutputStream; import java.io.Writer; import java.util.Optional; import java.util.zip.GZIPOutputStream; @@ -60,14 +63,17 @@ import okhttp3.OkHttpClient; import okhttp3.Request; import okhttp3.RequestBody; import okhttp3.Response; +import okio.BufferedSink; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; import org.joda.time.DateTime; /** * Daily action that uploads unavailable domain names on applicable TLDs to BSA. * *

The upload is a single zipped text file containing combined details for all BSA-enrolled TLDs. - * The text is a newline-delimited list of punycoded fully qualified domain names, and contains all - * domains on each TLD that are registered and/or reserved. + * The text is a newline-delimited list of punycoded fully qualified domain names with a trailing + * newline at the end, and contains all domains on each TLD that are registered and/or reserved. * *

The file is also uploaded to GCS to preserve it as a record for ourselves. */ @@ -118,7 +124,7 @@ public class UploadBsaUnavailableDomainsAction implements Runnable { // TODO(mcilwain): Implement a date Cursor, have the cronjob run frequently, and short-circuit // the run if the daily upload is already completed. DateTime runTime = clock.nowUtc(); - String unavailableDomains = Joiner.on("\n").join(getUnavailableDomains(runTime)); + ImmutableSortedSet unavailableDomains = getUnavailableDomains(runTime); if (unavailableDomains.isEmpty()) { logger.atWarning().log("No unavailable domains found; terminating."); emailSender.sendNotification( @@ -136,12 +142,16 @@ public class UploadBsaUnavailableDomainsAction implements Runnable { } /** Uploads the unavailable domains list to GCS in the unavailable domains bucket. */ - boolean uploadToGcs(String unavailableDomains, DateTime runTime) { + boolean uploadToGcs(ImmutableSortedSet unavailableDomains, DateTime runTime) { logger.atInfo().log("Uploading unavailable names file to GCS in bucket %s", gcsBucket); BlobId blobId = BlobId.of(gcsBucket, createFilename(runTime)); + // `gcsUtils.openOutputStream` returns a buffered stream try (OutputStream gcsOutput = gcsUtils.openOutputStream(blobId); Writer osWriter = new OutputStreamWriter(gcsOutput, US_ASCII)) { - osWriter.write(unavailableDomains); + for (var domainName : unavailableDomains) { + osWriter.write(domainName); + osWriter.write("\n"); + } return true; } catch (Exception e) { logger.atSevere().withCause(e).log( @@ -150,10 +160,14 @@ public class UploadBsaUnavailableDomainsAction implements Runnable { } } - boolean uploadToBsa(String unavailableDomains, DateTime runTime) { + boolean uploadToBsa(ImmutableSortedSet unavailableDomains, DateTime runTime) { try { - byte[] gzippedContents = gzipUnavailableDomains(unavailableDomains); - String sha512Hash = ByteSource.wrap(gzippedContents).hash(Hashing.sha512()).toString(); + Hasher sha512Hasher = Hashing.sha512().newHasher(); + unavailableDomains.stream() + .map(name -> name + "\n") + .forEachOrdered(line -> sha512Hasher.putString(line, UTF_8)); + String sha512Hash = sha512Hasher.hash().toString(); + String filename = createFilename(runTime); OkHttpClient client = new OkHttpClient().newBuilder().build(); @@ -169,7 +183,9 @@ public class UploadBsaUnavailableDomainsAction implements Runnable { .addFormDataPart( "file", String.format("%s.gz", filename), - RequestBody.create(gzippedContents, MediaType.parse("application/octet-stream"))) + new StreamingRequestBody( + gzippedStream(unavailableDomains), + MediaType.parse("application/octet-stream"))) .build(); Request request = @@ -196,15 +212,6 @@ public class UploadBsaUnavailableDomainsAction implements Runnable { } } - private byte[] gzipUnavailableDomains(String unavailableDomains) throws IOException { - try (ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream()) { - try (GZIPOutputStream gzipOutputStream = new GZIPOutputStream(byteArrayOutputStream)) { - gzipOutputStream.write(unavailableDomains.getBytes(US_ASCII)); - } - return byteArrayOutputStream.toByteArray(); - } - } - private static String createFilename(DateTime runTime) { return String.format("unavailable_domains_%s.txt", runTime.toString()); } @@ -280,4 +287,65 @@ public class UploadBsaUnavailableDomainsAction implements Runnable { private static String toDomain(String domainLabel, Tld tld) { return String.format("%s.%s", domainLabel, tld.getTldStr()); } + + private InputStream gzippedStream(ImmutableSortedSet unavailableDomains) + throws IOException { + PipedInputStream inputStream = new PipedInputStream(); + PipedOutputStream outputStream = new PipedOutputStream(inputStream); + + new Thread( + () -> { + try { + gzipUnavailableDomains(outputStream, unavailableDomains); + } catch (Throwable e) { + logger.atSevere().withCause(e).log("Failed to gzip unavailable domains."); + try { + // This will cause the next read to throw an IOException. + inputStream.close(); + } catch (IOException ignore) { + // Won't happen for `PipedInputStream.close()` + } + } + }) + .start(); + + return inputStream; + } + + private void gzipUnavailableDomains( + PipedOutputStream outputStream, ImmutableSortedSet unavailableDomains) + throws IOException { + // `GZIPOutputStream` is buffered. + try (GZIPOutputStream gzipOutputStream = new GZIPOutputStream(outputStream)) { + for (String name : unavailableDomains) { + var line = name + "\n"; + gzipOutputStream.write(line.getBytes(US_ASCII)); + } + } + } + + private static class StreamingRequestBody extends RequestBody { + private final BufferedInputStream inputStream; + private final MediaType mediaType; + + StreamingRequestBody(InputStream inputStream, MediaType mediaType) { + this.inputStream = new BufferedInputStream(inputStream); + this.mediaType = mediaType; + } + + @Nullable + @Override + public MediaType contentType() { + return mediaType; + } + + @Override + public void writeTo(@NotNull BufferedSink bufferedSink) throws IOException { + byte[] buffer = new byte[2048]; + int bytesRead; + while ((bytesRead = inputStream.read(buffer)) != -1) { + bufferedSink.write(buffer, 0, bytesRead); + } + } + } } diff --git a/core/src/test/java/google/registry/bsa/UploadBsaUnavailableDomainsActionTest.java b/core/src/test/java/google/registry/bsa/UploadBsaUnavailableDomainsActionTest.java index 7c078e5d4..26a1831bb 100644 --- a/core/src/test/java/google/registry/bsa/UploadBsaUnavailableDomainsActionTest.java +++ b/core/src/test/java/google/registry/bsa/UploadBsaUnavailableDomainsActionTest.java @@ -20,13 +20,24 @@ import static google.registry.testing.DatabaseHelper.persistActiveDomain; import static google.registry.testing.DatabaseHelper.persistDeletedDomain; import static google.registry.testing.DatabaseHelper.persistReservedList; import static google.registry.testing.DatabaseHelper.persistResource; +import static google.registry.testing.LogsSubject.assertAboutLogs; import static google.registry.util.DateTimeUtils.START_OF_TIME; +import static google.registry.util.NetworkUtils.pickUnusedPort; import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.concurrent.Executors.newSingleThreadExecutor; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import com.google.cloud.storage.BlobId; import com.google.cloud.storage.contrib.nio.testing.LocalStorageHelper; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.flogger.FluentLogger; +import com.google.common.hash.Hashing; +import com.google.common.io.ByteStreams; +import com.google.common.net.HostAndPort; +import com.google.common.testing.TestLogHandler; +import com.google.gson.Gson; import google.registry.bsa.api.BsaCredential; import google.registry.gcs.GcsUtils; import google.registry.model.tld.Tld; @@ -35,9 +46,25 @@ import google.registry.model.tld.label.ReservedList; import google.registry.persistence.transaction.JpaTestExtensions; import google.registry.persistence.transaction.JpaTestExtensions.JpaIntegrationTestExtension; import google.registry.request.UrlConnectionService; +import google.registry.server.Route; +import google.registry.server.TestServer; import google.registry.testing.FakeClock; import google.registry.testing.FakeResponse; +import jakarta.servlet.ServletException; +import jakarta.servlet.annotation.MultipartConfig; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import jakarta.servlet.http.Part; +import java.io.IOException; +import java.io.InputStream; +import java.io.PrintWriter; +import java.net.InetAddress; +import java.util.Map; import java.util.Optional; +import java.util.logging.Level; +import java.util.logging.Logger; +import java.util.zip.GZIPInputStream; import org.joda.time.DateTime; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -102,13 +129,112 @@ public class UploadBsaUnavailableDomainsActionTest { BlobId existingFile = BlobId.of(BUCKET, String.format("unavailable_domains_%s.txt", clock.nowUtc())); String blockList = new String(gcsUtils.readBytesFrom(existingFile), UTF_8); - assertThat(blockList).isEqualTo("ace.tld\nflagrant.tld\nfoobar.tld\njimmy.tld\ntine.tld"); + assertThat(blockList).isEqualTo("ace.tld\nflagrant.tld\nfoobar.tld\njimmy.tld\ntine.tld\n"); assertThat(blockList).doesNotContain("not-blocked.tld"); // This test currently fails in the upload-to-bsa step. verify(emailSender, times(1)) .sendNotification("BSA daily upload completed with errors", "Please see logs for details."); + } - // TODO(mcilwain): Add test of BSA API upload as well. + @Test + void uploadToBsaTest() throws Exception { + TestLogHandler logHandler = new TestLogHandler(); + Logger loggerToIntercept = + Logger.getLogger(UploadBsaUnavailableDomainsAction.class.getCanonicalName()); + loggerToIntercept.addHandler(logHandler); + + persistActiveDomain("foobar.tld"); + persistActiveDomain("ace.tld"); + persistDeletedDomain("not-blocked.tld", clock.nowUtc().minusDays(1)); + + var testServer = startTestServer(); + action.apiUrl = testServer.getUrl("/upload").toURI().toString(); + try { + action.run(); + } finally { + testServer.stop(); + } + String dataSent = "ace.tld\nflagrant.tld\nfoobar.tld\njimmy.tld\ntine.tld\n"; + String checkSum = Hashing.sha512().hashString(dataSent, UTF_8).toString(); + String expectedResponse = + "Received response with code 200 from server: " + + String.format("Checksum: [%s]\n%s\n", checkSum, dataSent); + assertAboutLogs().that(logHandler).hasLogAtLevelWithMessage(Level.INFO, expectedResponse); + verify(emailSender, times(1)).sendNotification("BSA daily upload completed successfully", ""); + } + + private TestServer startTestServer() throws Exception { + TestServer testServer = + new TestServer( + HostAndPort.fromParts(InetAddress.getLocalHost().getHostAddress(), pickUnusedPort()), + ImmutableMap.of(), + ImmutableList.of(Route.route("/upload", Servelet.class))); + testServer.start(); + newSingleThreadExecutor() + .execute( + () -> { + try { + while (true) { + testServer.process(); + } + } catch (InterruptedException e) { + // Expected + } + }); + return testServer; + } + + @MultipartConfig( + location = "", // Directory for storing uploaded files. Use default when blank + maxFileSize = 10485760L, // 10MB + maxRequestSize = 20971520L, // 20MB + fileSizeThreshold = 1048576 // Save in memory if file size < 1MB + ) + public static class Servelet extends HttpServlet { + private static final FluentLogger logger = FluentLogger.forEnclosingClass(); + + @Override + protected void doPost(HttpServletRequest req, HttpServletResponse resp) + throws ServletException, IOException { + String checkSum = null; + String content = null; + try { + for (Part part : req.getParts()) { + switch (part.getName()) { + case "zone" -> checkSum = readChecksum(part); + case "file" -> content = readGzipped(part); + } + } + } catch (Exception e) { + logger.atInfo().withCause(e).log(""); + } + int status = checkSum == null || content == null ? 400 : 200; + resp.setStatus(status); + resp.setContentType("text/plain"); + try (PrintWriter writer = resp.getWriter()) { + writer.printf("Checksum: [%s]\n%s\n", checkSum, content); + } + } + + private String readChecksum(Part part) { + try (InputStream is = part.getInputStream()) { + return new Gson() + .fromJson(new String(ByteStreams.toByteArray(is), UTF_8), Map.class) + .getOrDefault("checkSum", "Not found") + .toString(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private String readGzipped(Part part) { + try (InputStream is = part.getInputStream(); + GZIPInputStream gis = new GZIPInputStream(is)) { + return new String(ByteStreams.toByteArray(gis), UTF_8); + } catch (IOException e) { + throw new RuntimeException(e); + } + } } } diff --git a/core/src/test/java/google/registry/server/TestServer.java b/core/src/test/java/google/registry/server/TestServer.java index 2e457b4da..c37a15358 100644 --- a/core/src/test/java/google/registry/server/TestServer.java +++ b/core/src/test/java/google/registry/server/TestServer.java @@ -26,10 +26,16 @@ import com.google.common.net.HostAndPort; import com.google.common.util.concurrent.SimpleTimeLimiter; import google.registry.util.RegistryEnvironment; import google.registry.util.UrlChecker; +import jakarta.servlet.MultipartConfigElement; +import jakarta.servlet.annotation.MultipartConfig; import jakarta.servlet.http.HttpServlet; +import java.io.IOException; import java.net.MalformedURLException; import java.net.URL; +import java.nio.file.Files; import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import java.util.concurrent.BlockingQueue; import java.util.concurrent.FutureTask; @@ -49,6 +55,8 @@ import org.eclipse.jetty.server.ServerConnector; * {@link #stop()} methods. However, a {@link #process()} method was added, which is used to process * requests made to servlets (not static files) in the calling thread. * + *

A servlet that expects multi-part requests should be annotated with {@link MultipartConfig}. + * *

Note: This server is intended for development purposes. For the love all that is good, * do not make this public-facing. * @@ -70,6 +78,7 @@ public final class TestServer { private final HostAndPort urlAddress; private final Server server = new Server(); private final BlockingQueue> requestQueue = new LinkedBlockingDeque<>(); + private List multiPartTmpDirs = new ArrayList<>(); /** * Creates a new instance, but does not begin serving. @@ -134,6 +143,13 @@ public final class TestServer { }, SHUTDOWN_TIMEOUT_MS, TimeUnit.MILLISECONDS); + for (var dir : multiPartTmpDirs) { + try { + Files.delete(dir); + } catch (Exception e) { + // Ignore + } + } } catch (Exception e) { throwIfUnchecked(e); throw new RuntimeException(e); @@ -161,7 +177,27 @@ public final class TestServer { StaticResourceServlet.configureServletHolder(holder, runfile.getKey(), runfile.getValue()); } for (Route route : routes) { - context.addServlet(wrapServlet(route.servletClass()), route.path()); + holder = context.addServlet(wrapServlet(route.servletClass()), route.path()); + MultipartConfig multipartConfig = route.servletClass().getAnnotation(MultipartConfig.class); + if (multipartConfig != null) { + try { + var location = multipartConfig.location(); + if (location == null || location.isBlank()) { + Path tmpDir = Files.createTempDirectory("TestServer_"); + multiPartTmpDirs.add(tmpDir); + location = tmpDir.toString(); + } + MultipartConfigElement multipartConfigElement = + new MultipartConfigElement( + location, + multipartConfig.maxFileSize(), + multipartConfig.maxRequestSize(), + multipartConfig.fileSizeThreshold()); + holder.getRegistration().setMultipartConfig(multipartConfigElement); + } catch (IOException e) { + throw new RuntimeException(e); + } + } } holder = context.addServlet(DefaultServlet.class, "/*"); holder.setInitParameter("aliases", "1");