Swallow DNS allowlist parsing errors if at least one record is valid.

This commit is contained in:
miyuko
2025-10-15 01:51:51 +01:00
parent afae6e42f3
commit 87262e82f0
3 changed files with 84 additions and 14 deletions

View File

@@ -157,19 +157,35 @@ func authorizeDNSAllowlist(r *http.Request) (*Authorization, error) {
}
allowlistHostname := fmt.Sprintf("_git-pages-repository.%s", host)
repoURLs, err := net.LookupTXT(allowlistHostname)
records, err := net.LookupTXT(allowlistHostname)
if err != nil {
return nil, AuthError{http.StatusUnauthorized,
fmt.Sprintf("failed to look up DNS repository allowlist: %s TXT", allowlistHostname)}
}
for _, repoURL := range repoURLs {
if parsedURL, err := url.Parse(repoURL); err != nil {
return nil, AuthError{http.StatusBadRequest,
fmt.Sprintf("failed to parse URL: %s TXT %q", allowlistHostname, repoURL)}
var (
repoURLs []string
errs []error
)
for _, record := range records {
if parsedURL, err := url.Parse(record); err != nil {
errs = append(errs, fmt.Errorf("failed to parse URL: %s TXT %q", allowlistHostname, record))
} else if !parsedURL.IsAbs() {
return nil, AuthError{http.StatusBadRequest,
fmt.Sprintf("repository URL is not absolute: %s TXT %q", allowlistHostname, repoURL)}
errs = append(errs, fmt.Errorf("repository URL is not absolute: %s TXT %q", allowlistHostname, record))
} else {
repoURLs = append(repoURLs, record)
}
}
if len(repoURLs) == 0 {
if len(records) > 0 {
errs = append([]error{AuthError{http.StatusUnauthorized,
fmt.Sprintf("no valid DNS TXT records for %s", allowlistHostname)}},
errs...)
return nil, joinErrors(errs...)
} else {
return nil, AuthError{http.StatusUnauthorized,
fmt.Sprintf("no DNS TXT records found for %s", allowlistHostname)}
}
}
@@ -351,7 +367,7 @@ func AuthorizeMetadataRetrieval(r *http.Request) (*Authorization, error) {
}
}
return nil, errors.Join(causes...)
return nil, joinErrors(causes...)
}
// Returns `repoURLs, err` where if `err == nil` then the request is authorized to clone from
@@ -421,7 +437,7 @@ func AuthorizeUpdateFromRepository(r *http.Request) (*Authorization, error) {
}
}
return nil, errors.Join(causes...)
return nil, joinErrors(causes...)
}
func AuthorizeRepository(repoURL string, auth *Authorization) error {
@@ -511,7 +527,7 @@ func AuthorizeUpdateFromArchive(r *http.Request) (*Authorization, error) {
return auth, nil
}
return nil, errors.Join(causes...)
return nil, joinErrors(causes...)
}
func CheckForbiddenDomain(r *http.Request) error {

View File

@@ -562,9 +562,7 @@ func ServePages(w http.ResponseWriter, r *http.Request) {
if err != nil {
var authErr AuthError
if errors.As(err, &authErr) {
message := fmt.Sprint(err)
http.Error(w, strings.ReplaceAll(message, "\n", "\n- "), authErr.code)
err = errors.New(strings.ReplaceAll(message, "\n", "; "))
http.Error(w, prettyErrMsg(err), authErr.code)
}
var tooLargeErr *http.MaxBytesError
if errors.As(err, &tooLargeErr) {

View File

@@ -1,6 +1,10 @@
package main
import "io"
import (
"errors"
"io"
"strings"
)
type BoundedReader struct {
inner io.Reader
@@ -23,3 +27,55 @@ func (reader *BoundedReader) Read(dest []byte) (count int, err error) {
reader.fuel -= int64(count)
return
}
type prettyError interface {
error
Pretty() string
}
func prettyErrMsg(err error) string {
switch cerr := err.(type) {
case prettyError:
return cerr.Pretty()
default:
return cerr.Error()
}
}
type prettyJoinError struct {
errs []error
}
func joinErrors(errs ...error) error {
if err := errors.Join(errs...); err != nil {
wrapErr := err.(interface{ Unwrap() []error })
return &prettyJoinError{errs: wrapErr.Unwrap()}
}
return nil
}
func (e *prettyJoinError) Error() string {
var s strings.Builder
for i, err := range e.errs {
if i > 0 {
s.WriteString("; ")
}
s.WriteString(err.Error())
}
return s.String()
}
func (e *prettyJoinError) Pretty() string {
var s strings.Builder
for i, err := range e.errs {
if i > 0 {
s.WriteString("\n- ")
}
s.WriteString(strings.ReplaceAll(prettyErrMsg(err), "\n", "\n "))
}
return s.String()
}
func (e *prettyJoinError) Unwrap() []error {
return e.errs
}