1
0
mirror of https://github.com/google/nomulus synced 2025-12-23 06:15:42 +00:00

Fix OOM in UploadBsaUnavailableDomains action (#2817)

* Fix OOM in UploadBsaUnavailableDomains action

The action was using string concatenation to generate the upload content.
This causes an OOM when string length exceeds 25MB on our current VM.

This PR switches to streaming upload.

Also added an HTTP upload test.

* Fix OOM in UploadBsaUnavailableDomains action

The action was using string concatenation to generate the upload content.
This causes an OOM when string length exceeds 25MB on our current VM.

This PR switches to streaming upload.

Also added an HTTP upload test.
This commit is contained in:
Weimin Yu
2025-09-03 18:25:56 +00:00
committed by GitHub
parent 5e1cd0120f
commit 77ab80f3dc
4 changed files with 256 additions and 24 deletions

View File

@@ -58,6 +58,8 @@ def fragileTestPatterns = [
// Changes cache timeouts and for some reason appears to have contention
// with other tests.
"google/registry/whois/WhoisCommandFactoryTest.*",
// Breaks random other tests when running with standardTests.
"google/registry/bsa/UploadBsaUnavailableDomainsActionTest.*",
// Currently changes a global configuration parameter that for some reason
// results in timestamp inversions for other tests. TODO(mmuller): fix.
"google/registry/flows/host/HostInfoFlowTest.*",

View File

@@ -25,16 +25,16 @@ import static google.registry.request.Action.Method.GET;
import static google.registry.request.Action.Method.POST;
import static jakarta.servlet.http.HttpServletResponse.SC_INTERNAL_SERVER_ERROR;
import static java.nio.charset.StandardCharsets.US_ASCII;
import static java.nio.charset.StandardCharsets.UTF_8;
import com.google.cloud.storage.BlobId;
import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSortedSet;
import com.google.common.collect.Ordering;
import com.google.common.flogger.FluentLogger;
import com.google.common.hash.Hasher;
import com.google.common.hash.Hashing;
import com.google.common.io.ByteSource;
import google.registry.bsa.api.BsaCredential;
import google.registry.config.RegistryConfig.Config;
import google.registry.gcs.GcsUtils;
@@ -47,10 +47,13 @@ import google.registry.request.auth.Auth;
import google.registry.util.Clock;
import jakarta.inject.Inject;
import jakarta.persistence.TypedQuery;
import java.io.ByteArrayOutputStream;
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.io.Writer;
import java.util.Optional;
import java.util.zip.GZIPOutputStream;
@@ -60,14 +63,17 @@ import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import okio.BufferedSink;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.joda.time.DateTime;
/**
* Daily action that uploads unavailable domain names on applicable TLDs to BSA.
*
* <p>The upload is a single zipped text file containing combined details for all BSA-enrolled TLDs.
* The text is a newline-delimited list of punycoded fully qualified domain names, and contains all
* domains on each TLD that are registered and/or reserved.
* The text is a newline-delimited list of punycoded fully qualified domain names with a trailing
* newline at the end, and contains all domains on each TLD that are registered and/or reserved.
*
* <p>The file is also uploaded to GCS to preserve it as a record for ourselves.
*/
@@ -118,7 +124,7 @@ public class UploadBsaUnavailableDomainsAction implements Runnable {
// TODO(mcilwain): Implement a date Cursor, have the cronjob run frequently, and short-circuit
// the run if the daily upload is already completed.
DateTime runTime = clock.nowUtc();
String unavailableDomains = Joiner.on("\n").join(getUnavailableDomains(runTime));
ImmutableSortedSet<String> unavailableDomains = getUnavailableDomains(runTime);
if (unavailableDomains.isEmpty()) {
logger.atWarning().log("No unavailable domains found; terminating.");
emailSender.sendNotification(
@@ -136,12 +142,16 @@ public class UploadBsaUnavailableDomainsAction implements Runnable {
}
/** Uploads the unavailable domains list to GCS in the unavailable domains bucket. */
boolean uploadToGcs(String unavailableDomains, DateTime runTime) {
boolean uploadToGcs(ImmutableSortedSet<String> unavailableDomains, DateTime runTime) {
logger.atInfo().log("Uploading unavailable names file to GCS in bucket %s", gcsBucket);
BlobId blobId = BlobId.of(gcsBucket, createFilename(runTime));
// `gcsUtils.openOutputStream` returns a buffered stream
try (OutputStream gcsOutput = gcsUtils.openOutputStream(blobId);
Writer osWriter = new OutputStreamWriter(gcsOutput, US_ASCII)) {
osWriter.write(unavailableDomains);
for (var domainName : unavailableDomains) {
osWriter.write(domainName);
osWriter.write("\n");
}
return true;
} catch (Exception e) {
logger.atSevere().withCause(e).log(
@@ -150,10 +160,14 @@ public class UploadBsaUnavailableDomainsAction implements Runnable {
}
}
boolean uploadToBsa(String unavailableDomains, DateTime runTime) {
boolean uploadToBsa(ImmutableSortedSet<String> unavailableDomains, DateTime runTime) {
try {
byte[] gzippedContents = gzipUnavailableDomains(unavailableDomains);
String sha512Hash = ByteSource.wrap(gzippedContents).hash(Hashing.sha512()).toString();
Hasher sha512Hasher = Hashing.sha512().newHasher();
unavailableDomains.stream()
.map(name -> name + "\n")
.forEachOrdered(line -> sha512Hasher.putString(line, UTF_8));
String sha512Hash = sha512Hasher.hash().toString();
String filename = createFilename(runTime);
OkHttpClient client = new OkHttpClient().newBuilder().build();
@@ -169,7 +183,9 @@ public class UploadBsaUnavailableDomainsAction implements Runnable {
.addFormDataPart(
"file",
String.format("%s.gz", filename),
RequestBody.create(gzippedContents, MediaType.parse("application/octet-stream")))
new StreamingRequestBody(
gzippedStream(unavailableDomains),
MediaType.parse("application/octet-stream")))
.build();
Request request =
@@ -196,15 +212,6 @@ public class UploadBsaUnavailableDomainsAction implements Runnable {
}
}
/**
 * Gzips the given newline-delimited domain list entirely in memory.
 *
 * @param unavailableDomains the text to compress; encoded as US-ASCII before compression
 * @return the complete gzip stream as a byte array
 * @throws IOException if compression fails
 */
private byte[] gzipUnavailableDomains(String unavailableDomains) throws IOException {
  ByteArrayOutputStream compressedBytes = new ByteArrayOutputStream();
  // The gzip stream must be closed (and thus finished) before the bytes are extracted.
  try (GZIPOutputStream gzipper = new GZIPOutputStream(compressedBytes)) {
    byte[] plainBytes = unavailableDomains.getBytes(US_ASCII);
    gzipper.write(plainBytes);
  }
  return compressedBytes.toByteArray();
}
/** Builds the upload file name by embedding the string form of {@code runTime}. */
private static String createFilename(DateTime runTime) {
  String timestamp = runTime.toString();
  return "unavailable_domains_" + timestamp + ".txt";
}
@@ -280,4 +287,65 @@ public class UploadBsaUnavailableDomainsAction implements Runnable {
/** Joins a domain label and the TLD string of {@code tld} with a dot. */
private static String toDomain(String domainLabel, Tld tld) {
  return domainLabel + "." + tld.getTldStr();
}
/**
 * Returns a stream of the gzipped, newline-terminated domain names.
 *
 * <p>Compression runs on a newly started thread that writes into a pipe; the returned stream is
 * the read end of that pipe, so the compressed data is never held in memory as a whole.
 *
 * @param unavailableDomains the names to compress, in iteration order
 * @return the read end of the pipe carrying the gzipped data
 * @throws IOException if the pipe ends cannot be connected
 */
private InputStream gzippedStream(ImmutableSortedSet<String> unavailableDomains)
    throws IOException {
  PipedInputStream pipeReadEnd = new PipedInputStream();
  PipedOutputStream pipeWriteEnd = new PipedOutputStream(pipeReadEnd);
  Runnable compressTask =
      () -> {
        try {
          gzipUnavailableDomains(pipeWriteEnd, unavailableDomains);
        } catch (Throwable e) {
          logger.atSevere().withCause(e).log("Failed to gzip unavailable domains.");
          try {
            // Closing the read end makes the consumer's next read throw an IOException.
            pipeReadEnd.close();
          } catch (IOException ignore) {
            // Won't happen for `PipedInputStream.close()`
          }
        }
      };
  Thread compressThread = new Thread(compressTask);
  compressThread.start();
  return pipeReadEnd;
}
/**
 * Writes every domain name, each followed by a newline, through a gzip stream into
 * {@code outputStream}, closing the gzip stream (and with it {@code outputStream}) when done.
 *
 * @param outputStream the pipe end that receives the compressed bytes
 * @param unavailableDomains the names to write, in iteration order
 * @throws IOException if writing to the pipe fails
 */
private void gzipUnavailableDomains(
    PipedOutputStream outputStream, ImmutableSortedSet<String> unavailableDomains)
    throws IOException {
  // `GZIPOutputStream` buffers internally, so no extra buffering layer is added.
  try (GZIPOutputStream compressor = new GZIPOutputStream(outputStream)) {
    for (String domainName : unavailableDomains) {
      byte[] lineBytes = (domainName + "\n").getBytes(US_ASCII);
      compressor.write(lineBytes);
    }
  }
}
/**
 * A {@link RequestBody} that copies its content from an {@link InputStream} instead of holding
 * the whole payload in memory.
 *
 * <p>The wrapped stream can be consumed only once, so the body reports itself as one-shot and
 * closes the stream at the end of {@link #writeTo}. A second call to {@link #writeTo} therefore
 * fails with an {@link IOException} rather than silently sending an empty body.
 */
private static class StreamingRequestBody extends RequestBody {
  private final BufferedInputStream inputStream;
  private final MediaType mediaType;

  /**
   * @param inputStream source of the body content; wrapped in a {@link BufferedInputStream}
   * @param mediaType value returned from {@link #contentType()}
   */
  StreamingRequestBody(InputStream inputStream, MediaType mediaType) {
    this.inputStream = new BufferedInputStream(inputStream);
    this.mediaType = mediaType;
  }

  @Nullable
  @Override
  public MediaType contentType() {
    return mediaType;
  }

  /** Returns {@code true}: the underlying stream cannot be replayed. */
  @Override
  public boolean isOneShot() {
    return true;
  }

  /**
   * Copies the remaining bytes of the wrapped stream into {@code bufferedSink}, then closes the
   * wrapped stream, whether the copy succeeded or failed.
   *
   * @throws IOException if reading the stream or writing to the sink fails
   */
  @Override
  public void writeTo(@NotNull BufferedSink bufferedSink) throws IOException {
    // try-with-resources releases the stream (and the pipe behind it) even on failure.
    try (InputStream source = inputStream) {
      byte[] buffer = new byte[2048];
      int bytesRead;
      while ((bytesRead = source.read(buffer)) != -1) {
        bufferedSink.write(buffer, 0, bytesRead);
      }
    }
  }
}
}

View File

@@ -20,13 +20,24 @@ import static google.registry.testing.DatabaseHelper.persistActiveDomain;
import static google.registry.testing.DatabaseHelper.persistDeletedDomain;
import static google.registry.testing.DatabaseHelper.persistReservedList;
import static google.registry.testing.DatabaseHelper.persistResource;
import static google.registry.testing.LogsSubject.assertAboutLogs;
import static google.registry.util.DateTimeUtils.START_OF_TIME;
import static google.registry.util.NetworkUtils.pickUnusedPort;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.concurrent.Executors.newSingleThreadExecutor;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import com.google.cloud.storage.BlobId;
import com.google.cloud.storage.contrib.nio.testing.LocalStorageHelper;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.flogger.FluentLogger;
import com.google.common.hash.Hashing;
import com.google.common.io.ByteStreams;
import com.google.common.net.HostAndPort;
import com.google.common.testing.TestLogHandler;
import com.google.gson.Gson;
import google.registry.bsa.api.BsaCredential;
import google.registry.gcs.GcsUtils;
import google.registry.model.tld.Tld;
@@ -35,9 +46,25 @@ import google.registry.model.tld.label.ReservedList;
import google.registry.persistence.transaction.JpaTestExtensions;
import google.registry.persistence.transaction.JpaTestExtensions.JpaIntegrationTestExtension;
import google.registry.request.UrlConnectionService;
import google.registry.server.Route;
import google.registry.server.TestServer;
import google.registry.testing.FakeClock;
import google.registry.testing.FakeResponse;
import jakarta.servlet.ServletException;
import jakarta.servlet.annotation.MultipartConfig;
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.Part;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.net.InetAddress;
import java.util.Map;
import java.util.Optional;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.zip.GZIPInputStream;
import org.joda.time.DateTime;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -102,13 +129,112 @@ public class UploadBsaUnavailableDomainsActionTest {
BlobId existingFile =
BlobId.of(BUCKET, String.format("unavailable_domains_%s.txt", clock.nowUtc()));
String blockList = new String(gcsUtils.readBytesFrom(existingFile), UTF_8);
assertThat(blockList).isEqualTo("ace.tld\nflagrant.tld\nfoobar.tld\njimmy.tld\ntine.tld");
assertThat(blockList).isEqualTo("ace.tld\nflagrant.tld\nfoobar.tld\njimmy.tld\ntine.tld\n");
assertThat(blockList).doesNotContain("not-blocked.tld");
// This test currently fails in the upload-to-bsa step.
verify(emailSender, times(1))
.sendNotification("BSA daily upload completed with errors", "Please see logs for details.");
}
// TODO(mcilwain): Add test of BSA API upload as well.
/**
 * Runs the action against a local {@link TestServer} and verifies, through the action's log
 * output, that the server received the expected checksum and gzipped domain list.
 *
 * <p>The log handler is removed in a {@code finally} block so that it does not stay attached to
 * the action's logger after this test finishes.
 */
@Test
void uploadToBsaTest() throws Exception {
  TestLogHandler logHandler = new TestLogHandler();
  Logger loggerToIntercept =
      Logger.getLogger(UploadBsaUnavailableDomainsAction.class.getCanonicalName());
  loggerToIntercept.addHandler(logHandler);
  try {
    persistActiveDomain("foobar.tld");
    persistActiveDomain("ace.tld");
    persistDeletedDomain("not-blocked.tld", clock.nowUtc().minusDays(1));
    var testServer = startTestServer();
    action.apiUrl = testServer.getUrl("/upload").toURI().toString();
    try {
      action.run();
    } finally {
      testServer.stop();
    }
    String dataSent = "ace.tld\nflagrant.tld\nfoobar.tld\njimmy.tld\ntine.tld\n";
    String checkSum = Hashing.sha512().hashString(dataSent, UTF_8).toString();
    // The test servlet echoes the checksum and the gunzipped content it received.
    String expectedResponse =
        "Received response with code 200 from server: "
            + String.format("Checksum: [%s]\n%s\n", checkSum, dataSent);
    assertAboutLogs().that(logHandler).hasLogAtLevelWithMessage(Level.INFO, expectedResponse);
    verify(emailSender, times(1)).sendNotification("BSA daily upload completed successfully", "");
  } finally {
    // Detach the handler so it does not keep collecting records in later tests.
    loggerToIntercept.removeHandler(logHandler);
  }
}
/**
 * Starts a {@link TestServer} on an unused local port that routes {@code /upload} to
 * {@link Servelet}, and begins processing its requests on a background thread.
 *
 * <p>The background thread is a daemon: nothing here shuts the executor down, so a non-daemon
 * worker would be left behind after the server is stopped.
 *
 * @return the started server; the caller is responsible for stopping it
 */
private TestServer startTestServer() throws Exception {
  TestServer testServer =
      new TestServer(
          HostAndPort.fromParts(InetAddress.getLocalHost().getHostAddress(), pickUnusedPort()),
          ImmutableMap.of(),
          ImmutableList.of(Route.route("/upload", Servelet.class)));
  testServer.start();
  newSingleThreadExecutor(
          runnable -> {
            Thread worker = new Thread(runnable, "bsa-upload-test-server");
            worker.setDaemon(true);
            return worker;
          })
      .execute(
          () -> {
            try {
              while (true) {
                testServer.process();
              }
            } catch (InterruptedException e) {
              // Expected
            }
          });
  return testServer;
}
/**
 * Test servlet that accepts the multipart upload and echoes back what it received.
 *
 * <p>The {@code zone} part is parsed as JSON and its {@code checkSum} entry extracted; the
 * {@code file} part is gunzipped. The response is 200 when both parts were read and 400
 * otherwise, and always carries the checksum and content in plain text.
 */
@MultipartConfig(
    location = "", // Directory for storing uploaded files. Use default when blank
    maxFileSize = 10485760L, // 10MB
    maxRequestSize = 20971520L, // 20MB
    fileSizeThreshold = 1048576 // Save in memory if file size < 1MB
)
public static class Servelet extends HttpServlet {
  private static final FluentLogger logger = FluentLogger.forEnclosingClass();

  @Override
  protected void doPost(HttpServletRequest req, HttpServletResponse resp)
      throws ServletException, IOException {
    String receivedChecksum = null;
    String receivedContent = null;
    try {
      for (Part part : req.getParts()) {
        switch (part.getName()) {
          case "zone":
            receivedChecksum = readChecksum(part);
            break;
          case "file":
            receivedContent = readGzipped(part);
            break;
          default:
            // Other parts are ignored.
            break;
        }
      }
    } catch (Exception e) {
      // A failure leaves one or both values null, which is reported as a 400 below.
      logger.atInfo().withCause(e).log("");
    }
    boolean bothPartsRead = receivedChecksum != null && receivedContent != null;
    resp.setStatus(bothPartsRead ? 200 : 400);
    resp.setContentType("text/plain");
    try (PrintWriter out = resp.getWriter()) {
      out.printf("Checksum: [%s]\n%s\n", receivedChecksum, receivedContent);
    }
  }

  /** Parses the part as a JSON object and returns its {@code checkSum} entry as a string. */
  private String readChecksum(Part part) {
    try (InputStream partStream = part.getInputStream()) {
      String json = new String(ByteStreams.toByteArray(partStream), UTF_8);
      return new Gson().fromJson(json, Map.class).getOrDefault("checkSum", "Not found").toString();
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }

  /** Gunzips the part's bytes and decodes the result as UTF-8 text. */
  private String readGzipped(Part part) {
    try (InputStream partStream = part.getInputStream();
        GZIPInputStream gunzipped = new GZIPInputStream(partStream)) {
      byte[] plainBytes = ByteStreams.toByteArray(gunzipped);
      return new String(plainBytes, UTF_8);
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }
}
}

View File

@@ -26,10 +26,16 @@ import com.google.common.net.HostAndPort;
import com.google.common.util.concurrent.SimpleTimeLimiter;
import google.registry.util.RegistryEnvironment;
import google.registry.util.UrlChecker;
import jakarta.servlet.MultipartConfigElement;
import jakarta.servlet.annotation.MultipartConfig;
import jakarta.servlet.http.HttpServlet;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.FutureTask;
@@ -49,6 +55,8 @@ import org.eclipse.jetty.server.ServerConnector;
* {@link #stop()} methods. However, a {@link #process()} method was added, which is used to process
* requests made to servlets (not static files) in the calling thread.
*
* <p>A servlet that expects multi-part requests should be annotated with {@link MultipartConfig}.
*
* <p><b>Note:</b> This server is intended for development purposes. For the love of all that is good,
* do not make this public-facing.
*
@@ -70,6 +78,7 @@ public final class TestServer {
private final HostAndPort urlAddress;
private final Server server = new Server();
private final BlockingQueue<FutureTask<Void>> requestQueue = new LinkedBlockingDeque<>();
private List<Path> multiPartTmpDirs = new ArrayList<>();
/**
* Creates a new instance, but does not begin serving.
@@ -134,6 +143,13 @@ public final class TestServer {
},
SHUTDOWN_TIMEOUT_MS,
TimeUnit.MILLISECONDS);
for (var dir : multiPartTmpDirs) {
try {
Files.delete(dir);
} catch (Exception e) {
// Ignore
}
}
} catch (Exception e) {
throwIfUnchecked(e);
throw new RuntimeException(e);
@@ -161,7 +177,27 @@ public final class TestServer {
StaticResourceServlet.configureServletHolder(holder, runfile.getKey(), runfile.getValue());
}
for (Route route : routes) {
context.addServlet(wrapServlet(route.servletClass()), route.path());
holder = context.addServlet(wrapServlet(route.servletClass()), route.path());
MultipartConfig multipartConfig = route.servletClass().getAnnotation(MultipartConfig.class);
if (multipartConfig != null) {
try {
var location = multipartConfig.location();
if (location == null || location.isBlank()) {
Path tmpDir = Files.createTempDirectory("TestServer_");
multiPartTmpDirs.add(tmpDir);
location = tmpDir.toString();
}
MultipartConfigElement multipartConfigElement =
new MultipartConfigElement(
location,
multipartConfig.maxFileSize(),
multipartConfig.maxRequestSize(),
multipartConfig.fileSizeThreshold());
holder.getRegistration().setMultipartConfig(multipartConfigElement);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}
holder = context.addServlet(DefaultServlet.class, "/*");
holder.setInitParameter("aliases", "1");