Refactor authentication code.

This commit is contained in:
Catherine
2025-09-18 19:23:59 +00:00
parent 6f932df886
commit 3c46169ba6
2 changed files with 167 additions and 61 deletions

View File

@@ -3,7 +3,9 @@ package main
import (
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"log"
"net"
"net/http"
"os"
@@ -11,6 +13,23 @@ import (
"strings"
)
type AuthError struct {
code int
error string
}
func (e AuthError) Error() string {
return e.error
}
func IsUnauthorized(err error) bool {
var authErr AuthError
if errors.As(err, &authErr) {
return authErr.code == http.StatusUnauthorized
}
return false
}
func GetHost(r *http.Request) string {
// FIXME: handle IDNA
host, _, err := net.SplitHostPort(r.Host)
@@ -21,28 +40,39 @@ func GetHost(r *http.Request) string {
return host
}
func Authorize(w http.ResponseWriter, r *http.Request) error {
host := GetHost(r)
if os.Getenv("INSECURE") != "" {
return nil // for testing only
func GetProjectName(r *http.Request) (string, error) {
// path must be either `/` or `/foo/` (`/foo` is accepted as an alias)
path, _ := strings.CutPrefix(r.URL.Path, "/")
path, _ = strings.CutSuffix(path, "/")
if strings.HasPrefix(path, ".") {
return "", AuthError{http.StatusBadRequest, "directory name %s is reserved"}
} else if strings.Contains(path, "/") {
return "", AuthError{http.StatusBadRequest, "directories nested too deep"}
}
if path == "" {
// path `/` corresponds to pseudo-project `.index`
return ".index", nil
} else {
return path, nil
}
}
func authorizeDNSChallenge(r *http.Request) ([]string, error) {
host := GetHost(r)
authorization := r.Header.Get("Authorization")
if authorization == "" {
http.Error(w, "missing Authorization header", http.StatusUnauthorized)
return fmt.Errorf("missing Authorization header")
return nil, AuthError{http.StatusUnauthorized, "missing Authorization header"}
}
scheme, param, success := strings.Cut(authorization, " ")
if !success {
http.Error(w, "malformed Authorization header", http.StatusBadRequest)
return fmt.Errorf("malformed Authorization header")
return nil, AuthError{http.StatusBadRequest, "malformed Authorization header"}
}
if scheme != "Pages" && scheme != "Basic" {
http.Error(w, "unknown Authorization scheme", http.StatusBadRequest)
return fmt.Errorf("unknown Authorization scheme")
return nil, AuthError{http.StatusBadRequest, "unknown Authorization scheme"}
}
// services like GitHub and Gogs cannot send a custom Authorization: header, but supplying
@@ -50,19 +80,16 @@ func Authorize(w http.ResponseWriter, r *http.Request) error {
if scheme == "Basic" {
basicParam, err := base64.StdEncoding.DecodeString(param)
if err != nil {
http.Error(w, "malformed Authorization: Basic header", http.StatusBadRequest)
return fmt.Errorf("malformed Authorization: Basic header")
return nil, AuthError{http.StatusBadRequest, "malformed Authorization: Basic header"}
}
username, password, found := strings.Cut(string(basicParam), ":")
if !found {
http.Error(w, "malformed Authorization: Basic parameter", http.StatusBadRequest)
return fmt.Errorf("malformed Authorization: Basic parameter")
return nil, AuthError{http.StatusBadRequest, "malformed Authorization: Basic parameter"}
}
if username != "Pages" {
http.Error(w, "unexpected Authorization: Basic username", http.StatusUnauthorized)
return fmt.Errorf("unexpected Authorization: Basic username")
return nil, AuthError{http.StatusUnauthorized, "unexpected Authorization: Basic username"}
}
param = password
@@ -71,23 +98,109 @@ func Authorize(w http.ResponseWriter, r *http.Request) error {
challengeHostname := fmt.Sprintf("_git-pages-challenge.%s", host)
actualChallenges, err := net.LookupTXT(challengeHostname)
if err != nil {
http.Error(w, "failed to look up DNS challenge", http.StatusUnauthorized)
return fmt.Errorf("failed to look up %s: %w", challengeHostname, err)
return nil, AuthError{http.StatusUnauthorized,
fmt.Sprintf("failed to look up DNS challenge: TXT %s", challengeHostname)}
}
expectedChallenge := fmt.Sprintf("%x", sha256.Sum256(fmt.Appendf(nil, "%s %s", host, param)))
if !slices.Contains(actualChallenges, expectedChallenge) {
http.Error(w,
fmt.Sprintf("defeated by DNS challenge (%s not in %s)", expectedChallenge, challengeHostname),
http.StatusUnauthorized,
)
return fmt.Errorf(
"challenge mismatch for %s: %s does not contain %s",
return nil, AuthError{http.StatusUnauthorized, fmt.Sprintf(
"defeated by DNS challenge: TXT %s %v does not include %s",
challengeHostname,
actualChallenges,
expectedChallenge,
)
)}
}
return nil
return nil, nil
}
func authorizeWildcardDomain(r *http.Request) ([]string, error) {
host := GetHost(r)
hostParts := strings.Split(host, ".")
projectName, err := GetProjectName(r)
if err != nil {
return nil, err
}
if slices.Equal(hostParts[1:], strings.Split(config.Wildcard.Domain, ".")) {
userName := hostParts[0]
repoName := projectName
if repoName == ".index" {
repoName = fmt.Sprintf(config.Wildcard.IndexRepo, userName)
}
return []string{fmt.Sprintf(config.Wildcard.CloneURL, userName, repoName)}, nil
}
return nil, AuthError{
http.StatusUnauthorized,
fmt.Sprintf("domain %s does not match wildcard *.%s", host, config.Wildcard.Domain),
}
}
// Returns `repoURLs, err` where if `err == nil` then the request is authorized to clone from
// any repository URL exactly included in `repoURLs`, or any URL at all if `repoURLs == nil`.
func authorizeRequest(r *http.Request, allowWildcard bool) ([]string, error) {
causes := []error{AuthError{http.StatusUnauthorized, "unauthorized"}}
if os.Getenv("INSECURE") != "" {
log.Println("auth ok: INSECURE mode")
return nil, nil // for testing only
}
repoURLs, err := authorizeDNSChallenge(r)
if err != nil && IsUnauthorized(err) {
causes = append(causes, err)
} else if err != nil { // bad request
return nil, err
} else {
log.Println("auth ok: DNS challenge")
return repoURLs, nil
}
if allowWildcard {
repoURLs, err = authorizeWildcardDomain(r)
if err != nil && IsUnauthorized(err) {
causes = append(causes, err)
} else if err != nil { // bad request
return nil, err
} else {
log.Println("auth ok: wildcard *.%s: allow %v", config.Wildcard.Domain, repoURLs)
return repoURLs, nil
}
}
return nil, errors.Join(causes...)
}
func AuthorizeRequestWithWildcard(r *http.Request) ([]string, error) {
return authorizeRequest(r, true)
}
func AuthorizeRequestWithoutWildcard(r *http.Request) ([]string, error) {
return authorizeRequest(r, false)
}
func AuthorizeRepository(repoURL string, allowRepoURLs []string) error {
if allowRepoURLs == nil {
return nil // any
}
allowed := false
for _, allowRepoURL := range allowRepoURLs {
if strings.EqualFold(repoURL, allowRepoURL) {
allowed = true
break
}
}
if allowed {
return nil
} else {
return AuthError{
http.StatusUnauthorized,
fmt.Sprintf("clone URL not in allowlist %v", allowRepoURLs),
}
}
}

View File

@@ -4,13 +4,13 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"os"
"path"
"slices"
"strings"
"time"
@@ -153,25 +153,30 @@ func getProjectName(w http.ResponseWriter, r *http.Request) (string, error) {
func putPage(w http.ResponseWriter, r *http.Request) error {
host := GetHost(r)
err := Authorize(w, r)
projectName, err := GetProjectName(r)
if err != nil {
return err
}
projectName, err := getProjectName(w, r)
allowedRepoURLs, err := AuthorizeRequestWithoutWildcard(r)
if err != nil {
return err
}
requestBody, err := io.ReadAll(r.Body)
// URLs have no length limit, but 64K seems enough for a repository URL
requestBody, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 65536))
if err != nil {
return fmt.Errorf("body read: %w", err)
}
// request body contains git repository URL `https://codeberg.org/...`
// request header X-Pages-Branch contains git branch, `pages` by default
webRoot := fmt.Sprintf("%s/%s", host, projectName)
// request body contains git repository URL
repoURL := string(requestBody)
if err := AuthorizeRepository(repoURL, allowedRepoURLs); err != nil {
return err
}
branch := r.Header.Get("X-Pages-Branch")
if branch == "" {
branch = "pages"
@@ -208,28 +213,15 @@ func putPage(w http.ResponseWriter, r *http.Request) error {
func postPage(w http.ResponseWriter, r *http.Request) error {
host := GetHost(r)
hostParts := strings.Split(host, ".")
projectName, err := getProjectName(w, r)
projectName, err := GetProjectName(r)
if err != nil {
return err
}
allowRepoURL := ""
if slices.Equal(hostParts[1:], strings.Split(config.Wildcard.Domain, ".")) {
// explicit authorization bypasses wildcard domain restrictions
if err := Authorize(w, r); err != nil {
userName := hostParts[0]
repoName := projectName
if repoName == ".index" {
repoName = fmt.Sprintf(config.Wildcard.IndexRepo, userName)
}
allowRepoURL = fmt.Sprintf(config.Wildcard.CloneURL, userName, repoName)
}
} else {
if err := Authorize(w, r); err != nil {
return err
}
allowedRepoURLs, err := AuthorizeRequestWithWildcard(r)
if err != nil {
return err
}
eventName := ""
@@ -261,7 +253,8 @@ func postPage(w http.ResponseWriter, r *http.Request) error {
return fmt.Errorf("invalid content type")
}
requestBody, err := io.ReadAll(r.Body)
// Event payloads have no length limit, but events bigger than 16M seem excessive.
requestBody, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 16*1048576))
if err != nil {
return fmt.Errorf("body read: %w", err)
}
@@ -276,19 +269,15 @@ func postPage(w http.ResponseWriter, r *http.Request) error {
eventRef := event["ref"].(string)
if eventRef != "refs/heads/pages" {
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, "ref %s ignored\n", eventRef)
fmt.Fprintf(w, "ignored %s\n", eventRef)
return nil
}
webRoot := fmt.Sprintf("%s/%s", host, projectName)
repoURL := event["repository"].(map[string]any)["clone_url"].(string)
if allowRepoURL != "" && !strings.EqualFold(repoURL, allowRepoURL) {
http.Error(w,
fmt.Sprintf("wildcard domain requires repository to be %s", allowRepoURL),
http.StatusUnauthorized,
)
return fmt.Errorf("invalid clone URL")
if err := AuthorizeRepository(repoURL, allowedRepoURLs); err != nil {
return err
}
ctx, cancel := context.WithTimeout(r.Context(), updateTimeout)
@@ -330,10 +319,14 @@ func ServePages(w http.ResponseWriter, r *http.Request) {
err = fmt.Errorf("method %s not allowed", r.Method)
}
if err != nil {
if pathErr, ok := err.(*os.PathError); ok {
var authErr AuthError
if errors.As(err, &authErr) {
message := fmt.Sprint(err)
http.Error(w, strings.ReplaceAll(message, "\n", "\n- "), authErr.code)
err = errors.New(strings.ReplaceAll(message, "\n", "; "))
} else if pathErr, ok := err.(*os.PathError); ok {
err = fmt.Errorf("not found: %s", pathErr.Path)
}
if minioErr, ok := err.(minio.ErrorResponse); ok && minioErr.Code == "NoSuchKey" {
} else if minioErr, ok := err.(minio.ErrorResponse); ok && minioErr.Code == "NoSuchKey" {
err = fmt.Errorf("not found: %s", minioErr.Key)
}
log.Println("pages err:", err)