From 87262e82f0b24c0ce0f8b092c72b75c8a205ec5a Mon Sep 17 00:00:00 2001 From: miyuko Date: Wed, 15 Oct 2025 01:51:51 +0100 Subject: [PATCH] Swallow DNS allowlist parsing errors if at least one record is valid. --- src/auth.go | 36 +++++++++++++++++++++++--------- src/pages.go | 4 +--- src/util.go | 58 +++++++++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 84 insertions(+), 14 deletions(-) diff --git a/src/auth.go b/src/auth.go index facc1e8..9aa8666 100644 --- a/src/auth.go +++ b/src/auth.go @@ -157,19 +157,35 @@ func authorizeDNSAllowlist(r *http.Request) (*Authorization, error) { } allowlistHostname := fmt.Sprintf("_git-pages-repository.%s", host) - repoURLs, err := net.LookupTXT(allowlistHostname) + records, err := net.LookupTXT(allowlistHostname) if err != nil { return nil, AuthError{http.StatusUnauthorized, fmt.Sprintf("failed to look up DNS repository allowlist: %s TXT", allowlistHostname)} } - for _, repoURL := range repoURLs { - if parsedURL, err := url.Parse(repoURL); err != nil { - return nil, AuthError{http.StatusBadRequest, - fmt.Sprintf("failed to parse URL: %s TXT %q", allowlistHostname, repoURL)} + var ( + repoURLs []string + errs []error + ) + for _, record := range records { + if parsedURL, err := url.Parse(record); err != nil { + errs = append(errs, fmt.Errorf("failed to parse URL: %s TXT %q", allowlistHostname, record)) } else if !parsedURL.IsAbs() { - return nil, AuthError{http.StatusBadRequest, - fmt.Sprintf("repository URL is not absolute: %s TXT %q", allowlistHostname, repoURL)} + errs = append(errs, fmt.Errorf("repository URL is not absolute: %s TXT %q", allowlistHostname, record)) + } else { + repoURLs = append(repoURLs, record) + } + } + + if len(repoURLs) == 0 { + if len(records) > 0 { + errs = append([]error{AuthError{http.StatusUnauthorized, + fmt.Sprintf("no valid DNS TXT records for %s", allowlistHostname)}}, + errs...) + return nil, joinErrors(errs...) + } else { + return nil, AuthError{http.StatusUnauthorized, + fmt.Sprintf("no DNS TXT records found for %s", allowlistHostname)} } } @@ -351,7 +367,7 @@ func AuthorizeMetadataRetrieval(r *http.Request) (*Authorization, error) { } } - return nil, errors.Join(causes...) + return nil, joinErrors(causes...) } // Returns `repoURLs, err` where if `err == nil` then the request is authorized to clone from @@ -421,7 +437,7 @@ func AuthorizeUpdateFromRepository(r *http.Request) (*Authorization, error) { } } - return nil, errors.Join(causes...) + return nil, joinErrors(causes...) } func AuthorizeRepository(repoURL string, auth *Authorization) error { @@ -511,7 +527,7 @@ func AuthorizeUpdateFromArchive(r *http.Request) (*Authorization, error) { return auth, nil } - return nil, errors.Join(causes...) + return nil, joinErrors(causes...) } func CheckForbiddenDomain(r *http.Request) error { diff --git a/src/pages.go b/src/pages.go index f858e8b..86d89e4 100644 --- a/src/pages.go +++ b/src/pages.go @@ -562,9 +562,7 @@ func ServePages(w http.ResponseWriter, r *http.Request) { if err != nil { var authErr AuthError if errors.As(err, &authErr) { - message := fmt.Sprint(err) - http.Error(w, strings.ReplaceAll(message, "\n", "\n- "), authErr.code) - err = errors.New(strings.ReplaceAll(message, "\n", "; ")) + http.Error(w, prettyErrMsg(err), authErr.code) } var tooLargeErr *http.MaxBytesError if errors.As(err, &tooLargeErr) { diff --git a/src/util.go b/src/util.go index 4d92a1c..6347dd9 100644 --- a/src/util.go +++ b/src/util.go @@ -1,6 +1,10 @@ package main -import "io" +import ( + "errors" + "io" + "strings" +) type BoundedReader struct { inner io.Reader @@ -23,3 +27,55 @@ func (reader *BoundedReader) Read(dest []byte) (count int, err error) { reader.fuel -= int64(count) return } + +type prettyError interface { + error + Pretty() string +} + +func prettyErrMsg(err error) string { + switch cerr := err.(type) { + case prettyError: + return cerr.Pretty() + default: + return cerr.Error() + } +} + +type prettyJoinError struct { + errs []error +} + +func joinErrors(errs ...error) error { + if err := errors.Join(errs...); err != nil { + wrapErr := err.(interface{ Unwrap() []error }) + return &prettyJoinError{errs: wrapErr.Unwrap()} + } + return nil +} + +func (e *prettyJoinError) Error() string { + var s strings.Builder + for i, err := range e.errs { + if i > 0 { + s.WriteString("; ") + } + s.WriteString(err.Error()) + } + return s.String() +} + +func (e *prettyJoinError) Pretty() string { + var s strings.Builder + for i, err := range e.errs { + if i > 0 { + s.WriteString("\n- ") + } + s.WriteString(strings.ReplaceAll(prettyErrMsg(err), "\n", "\n ")) + } + return s.String() +} + +func (e *prettyJoinError) Unwrap() []error { + return e.errs +}