1
0
mirror of https://github.com/google/nomulus synced 2026-05-31 12:06:32 +00:00

Address technical debt and improve safety in domain flows and models (#3065)

* Address technical debt and improve safety in domain flows and models

- Addressed unhandled empty lists and swallowed exceptions in DomainFlowTmchUtils.
- Improved null safety and immutability guarantees in Fee and LaunchPhase.
- Applied defensive copying in FeeTransformResponseExtension.
  Note: This uses the forceEmptyToNull(nullToEmptyImmutableCopy(...))
  pattern. This defensive copy ensures immutability, while forceEmptyToNull
  is required because JAXB will serialize an empty collection as an empty
  XML tag (which violates EPP XML schemas). Setting it to null ensures
  JAXB omits the tag entirely.
- Corrected JAXB property suppression in FeeCheckResponseExtensionItemStdV1.

* Add pr-polisher skill for automated PR pre-flight checks

* Enhance pr-polisher with more GEMINI.md constraints

Added checks for:
- Incorrect @Nullable imports.
- Unstatically imported utility methods (DateTimeUtils/CacheUtils).
- Redundant transaction wrapping (tm().transact -> tm().reTransact).
- Mutable collection instantiations (ArrayList/HashMap).
This commit is contained in:
Ben McIlwain
2026-05-27 17:45:11 -04:00
committed by GitHub
parent 00522fb618
commit 2bc07349a4
8 changed files with 368 additions and 27 deletions

View File

@@ -0,0 +1,40 @@
---
name: pr-polisher
description: Automated pre-flight checklist to polish PRs. Use this before declaring a task or PR complete to automatically verify license headers, commit hygiene, formatting, and codebase mandates.
---
# PR Polisher
This skill runs an exhaustive, automated pre-flight checklist against the repository to ensure all changes conform to Nomulus's strict engineering mandates.
## When to Use
You MUST activate and execute this workflow immediately before you are about to declare a PR, task, or codebase refactor "done" or ready for human review. Do not declare the task complete until this workflow succeeds with 0 errors.
## Workflow Execution Steps
1. **Run the Automated Analysis Script**
Execute the packaged Python diff-checker script. This script automatically checks commit messages, working tree status, `package-lock.json` modifications, copyright years on new files, and a litany of anti-patterns using regex (e.g., fully-qualified names, incorrect clock injections, generic exception catching).
```bash
python3 ./pr-polisher/scripts/check_diff.py
```
2. **Run Formatting Validation**
Always run the project's formatting tools to ensure checkstyle passes.
```bash
./gradlew spotlessCheck
# OR if formatting is needed:
./gradlew spotlessApply && ./gradlew javaIncrementalFormatApply
```
3. **Verify Test Coverage Additions**
Review your diff (`git diff HEAD^`). If you have added any *new* public methods or modified core logic, manually verify that you have added tests to the corresponding `Test.java` file. A code review is not thorough if it only checks for compilation.
4. **Address Errors & Amend**
If any script throws an error, or if formatting changes were applied, you must stage those fixes and amend your commit:
```bash
git add -u
git commit --amend --no-edit
```
Loop back to Step 1 until the `check_diff.py` script returns `0 ERRORS` and the working directory is clean.

View File

@@ -0,0 +1,189 @@
#!/usr/bin/env python3
import subprocess
import re
import sys
import datetime
# Color codes
RED = "\03.3[91m"
YELLOW = "\03.3[93m"
GREEN = "\03.3[92m"
RESET = "\03.3[0m"
errors_found = 0
warnings_found = 0
def log_error(msg):
global errors_found
errors_found += 1
print(f"{RED}[ERROR]{RESET} {msg}")
def log_warning(msg):
global warnings_found
warnings_found += 1
print(f"{YELLOW}[WARNING]{RESET} {msg}")
def log_success(msg):
print(f"{GREEN}[OK]{RESET} {msg}")
def run_cmd(cmd):
return subprocess.check_output(cmd, shell=True, text=True).strip()
def check_commit_message():
print("--- Checking Commit Message ---")
try:
msg = run_cmd("git log -1 --pretty=format:%B")
lines = msg.split('\n')
subject = lines[0]
if len(subject) > 50:
log_error(f"Commit subject exceeds 50 characters ({len(subject)} chars): '{subject}'")
if not subject[0].isupper():
log_error(f"Commit subject must be capitalized: '{subject}'")
if subject[-1] in ['.', '!', '?']:
log_error(f"Commit subject must not end with punctuation: '{subject}'")
else:
log_success("Commit message format looks good.")
except Exception as e:
log_error(f"Failed to check commit message: {e}")
def check_workspace_clean():
print("\n--- Checking Workspace State ---")
status = run_cmd("git status --porcelain")
if status:
log_error("Workspace is not clean. Uncommitted changes found:\n" + status)
else:
log_success("Working directory is clean.")
def check_package_lock():
print("\n--- Checking package-lock.json ---")
diff_files = run_cmd("git diff HEAD^ --name-only").split('\n')
if "console-webapp/package-lock.json" in diff_files:
log_error("console-webapp/package-lock.json is modified in the diff. Unless NPM dependencies were explicitly changed, revert this file using: git checkout console-webapp/package-lock.json")
else:
log_success("console-webapp/package-lock.json is untouched.")
def check_license_headers():
print("\n--- Checking License Headers on New Files ---")
current_year = str(datetime.datetime.now().year)
added_files = run_cmd("git diff HEAD^ --name-status --diff-filter=A").split('\n')
added_java_files = [f.split('\t')[-1] for f in added_files if f.endswith('.java')]
expected_header = f"// Copyright {current_year} The Nomulus Authors. All Rights Reserved."
for f in added_java_files:
try:
with open(f, 'r') as file:
content = file.read()
if expected_header not in content:
log_error(f"Missing or incorrect copyright year in {f}. Expected: {expected_header}")
except FileNotFoundError:
pass
if not added_java_files:
log_success("No new Java files added.")
def check_diff_anti_patterns():
print("\n--- Checking Code Anti-Patterns in Diff ---")
diff = run_cmd("git diff HEAD^ -U0")
current_file = ""
# Regex Patterns
fqn_pattern = re.compile(r'(?<!import\s)(java\.[a-z0-9.]+\.[A-Z][a-zA-Z0-9]+|google\.registry\.[a-z0-9.]+\.[A-Z][a-zA-Z0-9]+)')
visibility_pattern = re.compile(r'/\*\s*package\s*\*/')
utc_pattern = re.compile(r'ZoneId\.of\("UTC"\)')
now_pattern = re.compile(r'(Instant\.now\(\)|OffsetDateTime\.now\(\)|System\.currentTimeMillis\(\))')
catch_generic_pattern = re.compile(r'catch\s*\(\s*(Exception|Throwable)\s+[a-zA-Z0-9_]+\s*\)')
is_equal_optional_pattern = re.compile(r'\.isEqualTo\(Optional\.of\(')
sleep_pattern = re.compile(r'Thread\.sleep\(')
suppress_pattern = re.compile(r'@SuppressWarnings\(')
wrong_nullable_pattern = re.compile(r'import\s+(?!javax\.annotation\.Nullable;)[a-zA-Z0-9_.]+\.Nullable;')
utility_class_pattern = re.compile(r'\b(DateTimeUtils|CacheUtils)\.[a-z]')
redundant_tx_pattern = re.compile(r'tm\(\)\.transact\(\s*\(\)\s*->\s*tm\(\)\.reTransact')
mutable_collection_pattern = re.compile(r'new\s+(ArrayList|HashMap|HashSet)\s*[<()]')
suppress_count = 0
for line in diff.split('\n'):
if line.startswith('+++ b/'):
current_file = line[6:]
suppress_count = 0
continue
if line.startswith('+') and not line.startswith('+++') and current_file.endswith('.java'):
code_line = line[1:]
# FQN Check
fqn_matches = fqn_pattern.findall(code_line)
if fqn_matches:
# Skip if the match is exactly part of an import or package declaration
if not code_line.strip().startswith('import') and not code_line.strip().startswith('package'):
log_warning(f"[{current_file}] Potential Fully-Qualified Name found: {fqn_matches}. Use imports instead.")
# Package visibility
if visibility_pattern.search(code_line):
log_error(f"[{current_file}] Found '/* package */' modifier. Leave modifier blank instead.")
# Time zones
if utc_pattern.search(code_line):
log_error(f"[{current_file}] Found ZoneId.of(\"UTC\"). Use statically imported ZoneOffset.UTC instead.")
# System clocks
if now_pattern.search(code_line):
log_error(f"[{current_file}] Found un-injected clock (Instant.now / System.currentTimeMillis). Inject Clock instead.")
# Catch generic
if catch_generic_pattern.search(code_line):
log_warning(f"[{current_file}] Catching generic Exception/Throwable. Use specific exceptions.")
# Truth Optionals
if is_equal_optional_pattern.search(code_line):
log_warning(f"[{current_file}] Found .isEqualTo(Optional.of(...)). Use Truth's .hasValue(...) instead.")
# Thread.sleep
if sleep_pattern.search(code_line):
log_warning(f"[{current_file}] Found Thread.sleep(). Use Sleeper instead in tests.")
# SuppressWarnings
if suppress_pattern.search(code_line):
suppress_count += 1
if suppress_count > 1:
log_error(f"[{current_file}] Multiple @SuppressWarnings detected. They must be merged (e.g. {{\"unchecked\", \"foo\"}}).")
else:
suppress_count = 0
# Wrong Nullable
if wrong_nullable_pattern.search(code_line):
log_error(f"[{current_file}] Found incorrect Nullable import. Always use javax.annotation.Nullable.")
# Missing static imports for utilities
if utility_class_pattern.search(code_line):
if not code_line.strip().startswith('import'):
log_warning(f"[{current_file}] Found un-statically imported method from DateTimeUtils/CacheUtils. Use static imports.")
# Redundant transaction wrapping
if redundant_tx_pattern.search(code_line):
log_error(f"[{current_file}] Found redundant transaction wrapping (tm().transact(() -> tm().reTransact(...))).")
# Mutable collection instantiation
if mutable_collection_pattern.search(code_line):
log_warning(f"[{current_file}] Found mutable collection instantiation (ArrayList/HashMap/HashSet). Prefer Guava Immutable collections.")
def main():
print("========================================")
print(" NOMULUS PR POLISHER CHECKLIST ")
print("========================================\n")
check_commit_message()
check_workspace_clean()
check_package_lock()
check_license_headers()
check_diff_anti_patterns()
print("\n========================================")
if errors_found == 0 and warnings_found == 0:
print(f"{GREEN}SUCCESS: All checks passed. PR is polished!{RESET}")
else:
print(f"RESULTS: {RED}{errors_found} ERRORS{RESET}, {YELLOW}{warnings_found} WARNINGS{RESET}")
print("Please address the above issues before declaring the PR complete.")
sys.exit(1 if errors_found > 0 else 0)
if __name__ == "__main__":
main()

View File

@@ -57,6 +57,9 @@ public final class DomainFlowTmchUtils {
public SignedMark verifySignedMarks(
ImmutableList<AbstractSignedMark> signedMarks, String domainLabel, Instant now)
throws EppException {
if (signedMarks.isEmpty()) {
throw new SignedMarksListEmptyException();
}
if (signedMarks.size() > 1) {
throw new TooManySignedMarksException();
}
@@ -77,21 +80,21 @@ public final class DomainFlowTmchUtils {
public SignedMark verifyEncodedSignedMark(EncodedSignedMark encodedSignedMark, Instant now)
throws EppException {
if (!encodedSignedMark.getEncoding().equals("base64")) {
if (!"base64".equals(encodedSignedMark.getEncoding())) {
throw new Base64RequiredForEncodedSignedMarksException();
}
byte[] signedMarkData;
try {
signedMarkData = encodedSignedMark.getBytes();
} catch (IllegalStateException e) {
throw new SignedMarkEncodingErrorException();
throw new SignedMarkEncodingErrorException(e);
}
SignedMark signedMark;
try {
signedMark = unmarshalEpp(SignedMark.class, signedMarkData);
} catch (EppException e) {
throw new SignedMarkParsingErrorException();
throw new SignedMarkParsingErrorException(e);
}
if (SignedMarkRevocationList.get().isSmdRevoked(signedMark.getId(), now)) {
@@ -101,22 +104,22 @@ public final class DomainFlowTmchUtils {
try {
tmchXmlSignature.verify(signedMarkData);
} catch (CertificateExpiredException e) {
throw new SignedMarkCertificateExpiredException();
throw new SignedMarkCertificateExpiredException(e);
} catch (CertificateNotYetValidException e) {
throw new SignedMarkCertificateNotYetValidException();
throw new SignedMarkCertificateNotYetValidException(e);
} catch (CertificateRevokedException e) {
throw new SignedMarkCertificateRevokedException();
throw new SignedMarkCertificateRevokedException(e);
} catch (CertificateSignatureException e) {
throw new SignedMarkCertificateSignatureException();
throw new SignedMarkCertificateSignatureException(e);
} catch (SignatureException | XMLSignatureException e) {
throw new SignedMarkSignatureException();
throw new SignedMarkSignatureException(e);
} catch (GeneralSecurityException e) {
throw new SignedMarkCertificateInvalidException();
throw new SignedMarkCertificateInvalidException(e);
} catch (IOException
| MarshalException
| SAXException
| ParserConfigurationException e) {
throw new SignedMarkParsingErrorException();
throw new SignedMarkParsingErrorException(e);
}
if (now.isBefore(signedMark.getCreationTime())) {
@@ -181,6 +184,11 @@ public final class DomainFlowTmchUtils {
public SignedMarkCertificateRevokedException() {
super("Signed mark certificate was revoked");
}
public SignedMarkCertificateRevokedException(Throwable cause) {
this();
initCause(cause);
}
}
/** Certificate used in signed mark signature has expired. */
@@ -189,6 +197,11 @@ public final class DomainFlowTmchUtils {
public SignedMarkCertificateNotYetValidException() {
super("Signed mark certificate not yet valid");
}
public SignedMarkCertificateNotYetValidException(Throwable cause) {
this();
initCause(cause);
}
}
/** Certificate used in signed mark signature has expired. */
@@ -196,6 +209,11 @@ public final class DomainFlowTmchUtils {
public SignedMarkCertificateExpiredException() {
super("Signed mark certificate has expired");
}
public SignedMarkCertificateExpiredException(Throwable cause) {
this();
initCause(cause);
}
}
/** Certificate parsing error, or possibly a bad provider or algorithm. */
@@ -203,6 +221,11 @@ public final class DomainFlowTmchUtils {
public SignedMarkCertificateInvalidException() {
super("Signed mark certificate is invalid");
}
public SignedMarkCertificateInvalidException(Throwable cause) {
this();
initCause(cause);
}
}
/** Invalid signature on a signed mark. */
@@ -210,6 +233,11 @@ public final class DomainFlowTmchUtils {
public SignedMarkCertificateSignatureException() {
super("Signed mark certificate not signed by ICANN");
}
public SignedMarkCertificateSignatureException(Throwable cause) {
this();
initCause(cause);
}
}
/** Invalid signature on a signed mark. */
@@ -217,6 +245,11 @@ public final class DomainFlowTmchUtils {
public SignedMarkSignatureException() {
super("Signed mark signature is invalid");
}
public SignedMarkSignatureException(Throwable cause) {
this();
initCause(cause);
}
}
/** Signed marks must be encoded. */
@@ -226,6 +259,13 @@ public final class DomainFlowTmchUtils {
}
}
/** Signed marks list cannot be empty. */
static class SignedMarksListEmptyException extends RequiredParameterMissingException {
public SignedMarksListEmptyException() {
super("Signed marks list cannot be empty");
}
}
/** Only one signed mark is allowed per application. */
static class TooManySignedMarksException extends ParameterValuePolicyErrorException {
public TooManySignedMarksException() {
@@ -245,6 +285,11 @@ public final class DomainFlowTmchUtils {
public SignedMarkParsingErrorException() {
super("Error while parsing encoded signed mark data");
}
public SignedMarkParsingErrorException(Throwable cause) {
this();
initCause(cause);
}
}
/** Signed mark data is improperly encoded. */
@@ -252,6 +297,11 @@ public final class DomainFlowTmchUtils {
public SignedMarkEncodingErrorException() {
super("Signed mark data is improperly encoded");
}
public SignedMarkEncodingErrorException(Throwable cause) {
this();
initCause(cause);
}
}
}

View File

@@ -15,7 +15,6 @@
package google.registry.model.domain.fee;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static google.registry.util.PreconditionsUtils.checkArgumentNotNull;
import com.google.common.collect.ImmutableSet;
@@ -31,6 +30,13 @@ import java.time.Period;
*/
public class Fee extends BaseFee {
public static final ImmutableSet<String> FEE_EXTENSION_URIS =
ImmutableSet.of(
ServiceExtension.FEE_1_00.getUri(),
ServiceExtension.FEE_0_12.getUri(),
ServiceExtension.FEE_0_11.getUri(),
ServiceExtension.FEE_0_6.getUri());
@Override
public Fee clone() {
return (Fee) super.clone();
@@ -60,21 +66,15 @@ public class Fee extends BaseFee {
private static Fee createWithCustomDescription(
BigDecimal cost, FeeType type, boolean isPremium, String description) {
Fee instance = new Fee();
instance.cost = checkNotNull(cost);
checkArgument(instance.cost.signum() >= 0, "Cost must be a non-negative number");
instance.type = checkNotNull(type);
checkArgumentNotNull(cost, "Cost cannot be null");
checkArgument(cost.signum() >= 0, "Cost must be a non-negative number");
instance.cost = cost;
instance.type = type;
instance.isPremium = isPremium;
instance.description = description;
return instance;
}
public static final ImmutableSet<String> FEE_EXTENSION_URIS =
ImmutableSet.of(
ServiceExtension.FEE_1_00.getUri(),
ServiceExtension.FEE_0_12.getUri(),
ServiceExtension.FEE_0_11.getUri(),
ServiceExtension.FEE_0_6.getUri());
/** Builder for {@link Fee}. */
public static class Builder extends Buildable.Builder<Fee> {

View File

@@ -63,12 +63,12 @@ public class FeeTransformResponseExtension extends ImmutableObject implements Re
}
public Builder setFees(List<Fee> fees) {
getInstance().fees = fees;
getInstance().fees = forceEmptyToNull(nullToEmptyImmutableCopy(fees));
return this;
}
public Builder setCredits(List<Credit> credits) {
getInstance().credits = forceEmptyToNull(credits);
getInstance().credits = forceEmptyToNull(nullToEmptyImmutableCopy(credits));
return this;
}
}

View File

@@ -14,13 +14,12 @@
package google.registry.model.domain.feestdv1;
import static google.registry.util.CollectionUtils.forceEmptyToNull;
import com.google.common.collect.ImmutableList;
import google.registry.model.domain.Period;
import google.registry.model.domain.fee.Fee;
import google.registry.model.domain.fee.FeeCheckResponseExtensionItem;
import google.registry.model.domain.fee.FeeQueryCommandExtensionItem.CommandName;
import jakarta.xml.bind.annotation.XmlTransient;
import jakarta.xml.bind.annotation.XmlType;
/** The version 1.0 response for a domain check on a single resource. */
@@ -38,6 +37,7 @@ public class FeeCheckResponseExtensionItemStdV1 extends FeeCheckResponseExtensio
* doesn't support "period".
*/
@Override
@XmlTransient
public Period getPeriod() {
return super.getPeriod();
}
@@ -47,6 +47,7 @@ public class FeeCheckResponseExtensionItemStdV1 extends FeeCheckResponseExtensio
* doesn't support "fee".
*/
@Override
@XmlTransient
public ImmutableList<Fee> getFees() {
return super.getFees();
}
@@ -74,7 +75,7 @@ public class FeeCheckResponseExtensionItemStdV1 extends FeeCheckResponseExtensio
@Override
public Builder setFees(ImmutableList<Fee> fees) {
commandBuilder.setFee(forceEmptyToNull(ImmutableList.copyOf(fees)));
commandBuilder.setFee(fees);
return this;
}

View File

@@ -20,6 +20,7 @@ import google.registry.model.ImmutableObject;
import jakarta.xml.bind.annotation.XmlAttribute;
import jakarta.xml.bind.annotation.XmlValue;
import java.util.Objects;
import javax.annotation.Nullable;
/**
* The launch phase of the TLD being addressed by this command.
@@ -46,7 +47,7 @@ import java.util.Objects;
* sets it is the one that needs to make sure the domain isn't a trademark and that the fields are
* correct.
*/
public class LaunchPhase extends ImmutableObject {
public final class LaunchPhase extends ImmutableObject {
/**
* The phase during which trademark holders can submit domain registrations with trademark
@@ -70,6 +71,9 @@ public class LaunchPhase extends ImmutableObject {
return instance;
}
/** Private no-arg constructor required for JAXB and to enforce immutability elsewhere. */
private LaunchPhase() {}
@XmlValue String phase;
/**
@@ -79,6 +83,7 @@ public class LaunchPhase extends ImmutableObject {
* <p>This is currently unused, but is retained so that incoming XMLs that include a subphase can
* have it be reflected back.
*/
@Nullable
@XmlAttribute(name = "name")
String subphase;

View File

@@ -0,0 +1,56 @@
// Copyright 2026 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.flows.domain;
import static org.junit.jupiter.api.Assertions.assertThrows;
import com.google.common.collect.ImmutableList;
import google.registry.flows.domain.DomainFlowTmchUtils.SignedMarksListEmptyException;
import google.registry.flows.domain.DomainFlowTmchUtils.SignedMarksMustBeEncodedException;
import google.registry.flows.domain.DomainFlowTmchUtils.TooManySignedMarksException;
import google.registry.model.smd.AbstractSignedMark;
import java.time.Instant;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
class DomainFlowTmchUtilsTest {
private final DomainFlowTmchUtils tmchUtils = new DomainFlowTmchUtils(null);
@Test
void test_verifySignedMarks_emptyList() {
assertThrows(
SignedMarksListEmptyException.class,
() -> tmchUtils.verifySignedMarks(ImmutableList.of(), "example", Instant.now()));
}
@Test
void test_verifySignedMarks_tooManyMarks() {
AbstractSignedMark mark1 = Mockito.mock(AbstractSignedMark.class);
AbstractSignedMark mark2 = Mockito.mock(AbstractSignedMark.class);
assertThrows(
TooManySignedMarksException.class,
() ->
tmchUtils.verifySignedMarks(ImmutableList.of(mark1, mark2), "example", Instant.now()));
}
@Test
void test_verifySignedMarks_notEncoded() {
AbstractSignedMark mark1 = Mockito.mock(AbstractSignedMark.class);
assertThrows(
SignedMarksMustBeEncodedException.class,
() -> tmchUtils.verifySignedMarks(ImmutableList.of(mark1), "example", Instant.now()));
}
}