1
0
mirror of https://github.com/google/nomulus synced 2026-01-03 19:54:18 +00:00

Fix OOM in UploadBsaUnavailableDomains action (#2817)

* Fix OOM in UploadBsaUnavailableDomains action

The action was using string concatenation to generate the upload content.
This causes an OOM when string length exceeds 25MB on our current VM.

This PR witches to streaming upload.

Also added an HTTP upload test.

* Fix OOM in UploadBsaUnavailableDomains action

The action was using string concatenation to generate the upload content.
This causes an OOM when string length exceeds 25MB on our current VM.

This PR witches to streaming upload.

Also added an HTTP upload test.
This commit is contained in:
Weimin Yu
2025-09-03 18:25:56 +00:00
committed by GitHub
parent 5e1cd0120f
commit 77ab80f3dc
4 changed files with 256 additions and 24 deletions

View File

@@ -58,6 +58,8 @@ def fragileTestPatterns = [
// Changes cache timeouts and for some reason appears to have contention // Changes cache timeouts and for some reason appears to have contention
// with other tests. // with other tests.
"google/registry/whois/WhoisCommandFactoryTest.*", "google/registry/whois/WhoisCommandFactoryTest.*",
// Breaks random other tests when running with standardTests.
"google/registry/bsa/UploadBsaUnavailableDomainsActionTest.*",
// Currently changes a global configuration parameter that for some reason // Currently changes a global configuration parameter that for some reason
// results in timestamp inversions for other tests. TODO(mmuller): fix. // results in timestamp inversions for other tests. TODO(mmuller): fix.
"google/registry/flows/host/HostInfoFlowTest.*", "google/registry/flows/host/HostInfoFlowTest.*",

View File

@@ -25,16 +25,16 @@ import static google.registry.request.Action.Method.GET;
import static google.registry.request.Action.Method.POST; import static google.registry.request.Action.Method.POST;
import static jakarta.servlet.http.HttpServletResponse.SC_INTERNAL_SERVER_ERROR; import static jakarta.servlet.http.HttpServletResponse.SC_INTERNAL_SERVER_ERROR;
import static java.nio.charset.StandardCharsets.US_ASCII; import static java.nio.charset.StandardCharsets.US_ASCII;
import static java.nio.charset.StandardCharsets.UTF_8;
import com.google.cloud.storage.BlobId; import com.google.cloud.storage.BlobId;
import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSortedSet; import com.google.common.collect.ImmutableSortedSet;
import com.google.common.collect.Ordering; import com.google.common.collect.Ordering;
import com.google.common.flogger.FluentLogger; import com.google.common.flogger.FluentLogger;
import com.google.common.hash.Hasher;
import com.google.common.hash.Hashing; import com.google.common.hash.Hashing;
import com.google.common.io.ByteSource;
import google.registry.bsa.api.BsaCredential; import google.registry.bsa.api.BsaCredential;
import google.registry.config.RegistryConfig.Config; import google.registry.config.RegistryConfig.Config;
import google.registry.gcs.GcsUtils; import google.registry.gcs.GcsUtils;
@@ -47,10 +47,13 @@ import google.registry.request.auth.Auth;
import google.registry.util.Clock; import google.registry.util.Clock;
import jakarta.inject.Inject; import jakarta.inject.Inject;
import jakarta.persistence.TypedQuery; import jakarta.persistence.TypedQuery;
import java.io.ByteArrayOutputStream; import java.io.BufferedInputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream; import java.io.OutputStream;
import java.io.OutputStreamWriter; import java.io.OutputStreamWriter;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.io.Writer; import java.io.Writer;
import java.util.Optional; import java.util.Optional;
import java.util.zip.GZIPOutputStream; import java.util.zip.GZIPOutputStream;
@@ -60,14 +63,17 @@ import okhttp3.OkHttpClient;
import okhttp3.Request; import okhttp3.Request;
import okhttp3.RequestBody; import okhttp3.RequestBody;
import okhttp3.Response; import okhttp3.Response;
import okio.BufferedSink;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.joda.time.DateTime; import org.joda.time.DateTime;
/** /**
* Daily action that uploads unavailable domain names on applicable TLDs to BSA. * Daily action that uploads unavailable domain names on applicable TLDs to BSA.
* *
* <p>The upload is a single zipped text file containing combined details for all BSA-enrolled TLDs. * <p>The upload is a single zipped text file containing combined details for all BSA-enrolled TLDs.
* The text is a newline-delimited list of punycoded fully qualified domain names, and contains all * The text is a newline-delimited list of punycoded fully qualified domain names with a trailing
* domains on each TLD that are registered and/or reserved. * newline at the end, and contains all domains on each TLD that are registered and/or reserved.
* *
* <p>The file is also uploaded to GCS to preserve it as a record for ourselves. * <p>The file is also uploaded to GCS to preserve it as a record for ourselves.
*/ */
@@ -118,7 +124,7 @@ public class UploadBsaUnavailableDomainsAction implements Runnable {
// TODO(mcilwain): Implement a date Cursor, have the cronjob run frequently, and short-circuit // TODO(mcilwain): Implement a date Cursor, have the cronjob run frequently, and short-circuit
// the run if the daily upload is already completed. // the run if the daily upload is already completed.
DateTime runTime = clock.nowUtc(); DateTime runTime = clock.nowUtc();
String unavailableDomains = Joiner.on("\n").join(getUnavailableDomains(runTime)); ImmutableSortedSet<String> unavailableDomains = getUnavailableDomains(runTime);
if (unavailableDomains.isEmpty()) { if (unavailableDomains.isEmpty()) {
logger.atWarning().log("No unavailable domains found; terminating."); logger.atWarning().log("No unavailable domains found; terminating.");
emailSender.sendNotification( emailSender.sendNotification(
@@ -136,12 +142,16 @@ public class UploadBsaUnavailableDomainsAction implements Runnable {
} }
/** Uploads the unavailable domains list to GCS in the unavailable domains bucket. */ /** Uploads the unavailable domains list to GCS in the unavailable domains bucket. */
boolean uploadToGcs(String unavailableDomains, DateTime runTime) { boolean uploadToGcs(ImmutableSortedSet<String> unavailableDomains, DateTime runTime) {
logger.atInfo().log("Uploading unavailable names file to GCS in bucket %s", gcsBucket); logger.atInfo().log("Uploading unavailable names file to GCS in bucket %s", gcsBucket);
BlobId blobId = BlobId.of(gcsBucket, createFilename(runTime)); BlobId blobId = BlobId.of(gcsBucket, createFilename(runTime));
// `gcsUtils.openOutputStream` returns a buffered stream
try (OutputStream gcsOutput = gcsUtils.openOutputStream(blobId); try (OutputStream gcsOutput = gcsUtils.openOutputStream(blobId);
Writer osWriter = new OutputStreamWriter(gcsOutput, US_ASCII)) { Writer osWriter = new OutputStreamWriter(gcsOutput, US_ASCII)) {
osWriter.write(unavailableDomains); for (var domainName : unavailableDomains) {
osWriter.write(domainName);
osWriter.write("\n");
}
return true; return true;
} catch (Exception e) { } catch (Exception e) {
logger.atSevere().withCause(e).log( logger.atSevere().withCause(e).log(
@@ -150,10 +160,14 @@ public class UploadBsaUnavailableDomainsAction implements Runnable {
} }
} }
boolean uploadToBsa(String unavailableDomains, DateTime runTime) { boolean uploadToBsa(ImmutableSortedSet<String> unavailableDomains, DateTime runTime) {
try { try {
byte[] gzippedContents = gzipUnavailableDomains(unavailableDomains); Hasher sha512Hasher = Hashing.sha512().newHasher();
String sha512Hash = ByteSource.wrap(gzippedContents).hash(Hashing.sha512()).toString(); unavailableDomains.stream()
.map(name -> name + "\n")
.forEachOrdered(line -> sha512Hasher.putString(line, UTF_8));
String sha512Hash = sha512Hasher.hash().toString();
String filename = createFilename(runTime); String filename = createFilename(runTime);
OkHttpClient client = new OkHttpClient().newBuilder().build(); OkHttpClient client = new OkHttpClient().newBuilder().build();
@@ -169,7 +183,9 @@ public class UploadBsaUnavailableDomainsAction implements Runnable {
.addFormDataPart( .addFormDataPart(
"file", "file",
String.format("%s.gz", filename), String.format("%s.gz", filename),
RequestBody.create(gzippedContents, MediaType.parse("application/octet-stream"))) new StreamingRequestBody(
gzippedStream(unavailableDomains),
MediaType.parse("application/octet-stream")))
.build(); .build();
Request request = Request request =
@@ -196,15 +212,6 @@ public class UploadBsaUnavailableDomainsAction implements Runnable {
} }
} }
private byte[] gzipUnavailableDomains(String unavailableDomains) throws IOException {
try (ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream()) {
try (GZIPOutputStream gzipOutputStream = new GZIPOutputStream(byteArrayOutputStream)) {
gzipOutputStream.write(unavailableDomains.getBytes(US_ASCII));
}
return byteArrayOutputStream.toByteArray();
}
}
private static String createFilename(DateTime runTime) { private static String createFilename(DateTime runTime) {
return String.format("unavailable_domains_%s.txt", runTime.toString()); return String.format("unavailable_domains_%s.txt", runTime.toString());
} }
@@ -280,4 +287,65 @@ public class UploadBsaUnavailableDomainsAction implements Runnable {
private static String toDomain(String domainLabel, Tld tld) { private static String toDomain(String domainLabel, Tld tld) {
return String.format("%s.%s", domainLabel, tld.getTldStr()); return String.format("%s.%s", domainLabel, tld.getTldStr());
} }
private InputStream gzippedStream(ImmutableSortedSet<String> unavailableDomains)
throws IOException {
PipedInputStream inputStream = new PipedInputStream();
PipedOutputStream outputStream = new PipedOutputStream(inputStream);
new Thread(
() -> {
try {
gzipUnavailableDomains(outputStream, unavailableDomains);
} catch (Throwable e) {
logger.atSevere().withCause(e).log("Failed to gzip unavailable domains.");
try {
// This will cause the next read to throw an IOException.
inputStream.close();
} catch (IOException ignore) {
// Won't happen for `PipedInputStream.close()`
}
}
})
.start();
return inputStream;
}
private void gzipUnavailableDomains(
PipedOutputStream outputStream, ImmutableSortedSet<String> unavailableDomains)
throws IOException {
// `GZIPOutputStream` is buffered.
try (GZIPOutputStream gzipOutputStream = new GZIPOutputStream(outputStream)) {
for (String name : unavailableDomains) {
var line = name + "\n";
gzipOutputStream.write(line.getBytes(US_ASCII));
}
}
}
private static class StreamingRequestBody extends RequestBody {
private final BufferedInputStream inputStream;
private final MediaType mediaType;
StreamingRequestBody(InputStream inputStream, MediaType mediaType) {
this.inputStream = new BufferedInputStream(inputStream);
this.mediaType = mediaType;
}
@Nullable
@Override
public MediaType contentType() {
return mediaType;
}
@Override
public void writeTo(@NotNull BufferedSink bufferedSink) throws IOException {
byte[] buffer = new byte[2048];
int bytesRead;
while ((bytesRead = inputStream.read(buffer)) != -1) {
bufferedSink.write(buffer, 0, bytesRead);
}
}
}
} }

View File

@@ -20,13 +20,24 @@ import static google.registry.testing.DatabaseHelper.persistActiveDomain;
import static google.registry.testing.DatabaseHelper.persistDeletedDomain; import static google.registry.testing.DatabaseHelper.persistDeletedDomain;
import static google.registry.testing.DatabaseHelper.persistReservedList; import static google.registry.testing.DatabaseHelper.persistReservedList;
import static google.registry.testing.DatabaseHelper.persistResource; import static google.registry.testing.DatabaseHelper.persistResource;
import static google.registry.testing.LogsSubject.assertAboutLogs;
import static google.registry.util.DateTimeUtils.START_OF_TIME; import static google.registry.util.DateTimeUtils.START_OF_TIME;
import static google.registry.util.NetworkUtils.pickUnusedPort;
import static java.nio.charset.StandardCharsets.UTF_8; import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.concurrent.Executors.newSingleThreadExecutor;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import com.google.cloud.storage.BlobId; import com.google.cloud.storage.BlobId;
import com.google.cloud.storage.contrib.nio.testing.LocalStorageHelper; import com.google.cloud.storage.contrib.nio.testing.LocalStorageHelper;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.flogger.FluentLogger;
import com.google.common.hash.Hashing;
import com.google.common.io.ByteStreams;
import com.google.common.net.HostAndPort;
import com.google.common.testing.TestLogHandler;
import com.google.gson.Gson;
import google.registry.bsa.api.BsaCredential; import google.registry.bsa.api.BsaCredential;
import google.registry.gcs.GcsUtils; import google.registry.gcs.GcsUtils;
import google.registry.model.tld.Tld; import google.registry.model.tld.Tld;
@@ -35,9 +46,25 @@ import google.registry.model.tld.label.ReservedList;
import google.registry.persistence.transaction.JpaTestExtensions; import google.registry.persistence.transaction.JpaTestExtensions;
import google.registry.persistence.transaction.JpaTestExtensions.JpaIntegrationTestExtension; import google.registry.persistence.transaction.JpaTestExtensions.JpaIntegrationTestExtension;
import google.registry.request.UrlConnectionService; import google.registry.request.UrlConnectionService;
import google.registry.server.Route;
import google.registry.server.TestServer;
import google.registry.testing.FakeClock; import google.registry.testing.FakeClock;
import google.registry.testing.FakeResponse; import google.registry.testing.FakeResponse;
import jakarta.servlet.ServletException;
import jakarta.servlet.annotation.MultipartConfig;
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.Part;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.net.InetAddress;
import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.zip.GZIPInputStream;
import org.joda.time.DateTime; import org.joda.time.DateTime;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@@ -102,13 +129,112 @@ public class UploadBsaUnavailableDomainsActionTest {
BlobId existingFile = BlobId existingFile =
BlobId.of(BUCKET, String.format("unavailable_domains_%s.txt", clock.nowUtc())); BlobId.of(BUCKET, String.format("unavailable_domains_%s.txt", clock.nowUtc()));
String blockList = new String(gcsUtils.readBytesFrom(existingFile), UTF_8); String blockList = new String(gcsUtils.readBytesFrom(existingFile), UTF_8);
assertThat(blockList).isEqualTo("ace.tld\nflagrant.tld\nfoobar.tld\njimmy.tld\ntine.tld"); assertThat(blockList).isEqualTo("ace.tld\nflagrant.tld\nfoobar.tld\njimmy.tld\ntine.tld\n");
assertThat(blockList).doesNotContain("not-blocked.tld"); assertThat(blockList).doesNotContain("not-blocked.tld");
// This test currently fails in the upload-to-bsa step. // This test currently fails in the upload-to-bsa step.
verify(emailSender, times(1)) verify(emailSender, times(1))
.sendNotification("BSA daily upload completed with errors", "Please see logs for details."); .sendNotification("BSA daily upload completed with errors", "Please see logs for details.");
}
// TODO(mcilwain): Add test of BSA API upload as well. @Test
void uploadToBsaTest() throws Exception {
TestLogHandler logHandler = new TestLogHandler();
Logger loggerToIntercept =
Logger.getLogger(UploadBsaUnavailableDomainsAction.class.getCanonicalName());
loggerToIntercept.addHandler(logHandler);
persistActiveDomain("foobar.tld");
persistActiveDomain("ace.tld");
persistDeletedDomain("not-blocked.tld", clock.nowUtc().minusDays(1));
var testServer = startTestServer();
action.apiUrl = testServer.getUrl("/upload").toURI().toString();
try {
action.run();
} finally {
testServer.stop();
}
String dataSent = "ace.tld\nflagrant.tld\nfoobar.tld\njimmy.tld\ntine.tld\n";
String checkSum = Hashing.sha512().hashString(dataSent, UTF_8).toString();
String expectedResponse =
"Received response with code 200 from server: "
+ String.format("Checksum: [%s]\n%s\n", checkSum, dataSent);
assertAboutLogs().that(logHandler).hasLogAtLevelWithMessage(Level.INFO, expectedResponse);
verify(emailSender, times(1)).sendNotification("BSA daily upload completed successfully", "");
}
private TestServer startTestServer() throws Exception {
TestServer testServer =
new TestServer(
HostAndPort.fromParts(InetAddress.getLocalHost().getHostAddress(), pickUnusedPort()),
ImmutableMap.of(),
ImmutableList.of(Route.route("/upload", Servelet.class)));
testServer.start();
newSingleThreadExecutor()
.execute(
() -> {
try {
while (true) {
testServer.process();
}
} catch (InterruptedException e) {
// Expected
}
});
return testServer;
}
@MultipartConfig(
location = "", // Directory for storing uploaded files. Use default when blank
maxFileSize = 10485760L, // 10MB
maxRequestSize = 20971520L, // 20MB
fileSizeThreshold = 1048576 // Save in memory if file size < 1MB
)
public static class Servelet extends HttpServlet {
private static final FluentLogger logger = FluentLogger.forEnclosingClass();
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp)
throws ServletException, IOException {
String checkSum = null;
String content = null;
try {
for (Part part : req.getParts()) {
switch (part.getName()) {
case "zone" -> checkSum = readChecksum(part);
case "file" -> content = readGzipped(part);
}
}
} catch (Exception e) {
logger.atInfo().withCause(e).log("");
}
int status = checkSum == null || content == null ? 400 : 200;
resp.setStatus(status);
resp.setContentType("text/plain");
try (PrintWriter writer = resp.getWriter()) {
writer.printf("Checksum: [%s]\n%s\n", checkSum, content);
}
}
private String readChecksum(Part part) {
try (InputStream is = part.getInputStream()) {
return new Gson()
.fromJson(new String(ByteStreams.toByteArray(is), UTF_8), Map.class)
.getOrDefault("checkSum", "Not found")
.toString();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
private String readGzipped(Part part) {
try (InputStream is = part.getInputStream();
GZIPInputStream gis = new GZIPInputStream(is)) {
return new String(ByteStreams.toByteArray(gis), UTF_8);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
} }
} }

View File

@@ -26,10 +26,16 @@ import com.google.common.net.HostAndPort;
import com.google.common.util.concurrent.SimpleTimeLimiter; import com.google.common.util.concurrent.SimpleTimeLimiter;
import google.registry.util.RegistryEnvironment; import google.registry.util.RegistryEnvironment;
import google.registry.util.UrlChecker; import google.registry.util.UrlChecker;
import jakarta.servlet.MultipartConfigElement;
import jakarta.servlet.annotation.MultipartConfig;
import jakarta.servlet.http.HttpServlet; import jakarta.servlet.http.HttpServlet;
import java.io.IOException;
import java.net.MalformedURLException; import java.net.MalformedURLException;
import java.net.URL; import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path; import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.BlockingQueue; import java.util.concurrent.BlockingQueue;
import java.util.concurrent.FutureTask; import java.util.concurrent.FutureTask;
@@ -49,6 +55,8 @@ import org.eclipse.jetty.server.ServerConnector;
* {@link #stop()} methods. However, a {@link #process()} method was added, which is used to process * {@link #stop()} methods. However, a {@link #process()} method was added, which is used to process
* requests made to servlets (not static files) in the calling thread. * requests made to servlets (not static files) in the calling thread.
* *
* <p>A servlet that expects multi-part requests should be annotated with {@link MultipartConfig}.
*
* <p><b>Note:</b> This server is intended for development purposes. For the love all that is good, * <p><b>Note:</b> This server is intended for development purposes. For the love all that is good,
* do not make this public-facing. * do not make this public-facing.
* *
@@ -70,6 +78,7 @@ public final class TestServer {
private final HostAndPort urlAddress; private final HostAndPort urlAddress;
private final Server server = new Server(); private final Server server = new Server();
private final BlockingQueue<FutureTask<Void>> requestQueue = new LinkedBlockingDeque<>(); private final BlockingQueue<FutureTask<Void>> requestQueue = new LinkedBlockingDeque<>();
private List<Path> multiPartTmpDirs = new ArrayList<>();
/** /**
* Creates a new instance, but does not begin serving. * Creates a new instance, but does not begin serving.
@@ -134,6 +143,13 @@ public final class TestServer {
}, },
SHUTDOWN_TIMEOUT_MS, SHUTDOWN_TIMEOUT_MS,
TimeUnit.MILLISECONDS); TimeUnit.MILLISECONDS);
for (var dir : multiPartTmpDirs) {
try {
Files.delete(dir);
} catch (Exception e) {
// Ignore
}
}
} catch (Exception e) { } catch (Exception e) {
throwIfUnchecked(e); throwIfUnchecked(e);
throw new RuntimeException(e); throw new RuntimeException(e);
@@ -161,7 +177,27 @@ public final class TestServer {
StaticResourceServlet.configureServletHolder(holder, runfile.getKey(), runfile.getValue()); StaticResourceServlet.configureServletHolder(holder, runfile.getKey(), runfile.getValue());
} }
for (Route route : routes) { for (Route route : routes) {
context.addServlet(wrapServlet(route.servletClass()), route.path()); holder = context.addServlet(wrapServlet(route.servletClass()), route.path());
MultipartConfig multipartConfig = route.servletClass().getAnnotation(MultipartConfig.class);
if (multipartConfig != null) {
try {
var location = multipartConfig.location();
if (location == null || location.isBlank()) {
Path tmpDir = Files.createTempDirectory("TestServer_");
multiPartTmpDirs.add(tmpDir);
location = tmpDir.toString();
}
MultipartConfigElement multipartConfigElement =
new MultipartConfigElement(
location,
multipartConfig.maxFileSize(),
multipartConfig.maxRequestSize(),
multipartConfig.fileSizeThreshold());
holder.getRegistration().setMultipartConfig(multipartConfigElement);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
} }
holder = context.addServlet(DefaultServlet.class, "/*"); holder = context.addServlet(DefaultServlet.class, "/*");
holder.setInitParameter("aliases", "1"); holder.setInitParameter("aliases", "1");