diff --git a/conf/config.example.toml b/conf/config.example.toml index 3a716a8..bb37eb1 100644 --- a/conf/config.example.toml +++ b/conf/config.example.toml @@ -57,7 +57,7 @@ allowed-custom-headers = ["X-Clacks-Overhead"] [audit] node-id = 0 collect = false -include-ip = false +include-ip = "" notify-url = "" [observability] diff --git a/src/config.go b/src/config.go index e2cd683..acc1c66 100644 --- a/src/config.go +++ b/src/config.go @@ -152,8 +152,11 @@ type AuditConfig struct { NodeID int `toml:"node-id"` // Whether audit reports should be stored whenever an audit event occurs. Collect bool `toml:"collect"` - // Whether audit reports should include principal's IP address. - IncludeIPs bool `toml:"include-ip"` + // If not empty, includes the principal's IP address in audit reports, with the value specifying + // the source of the IP address. If the value is "X-Forwarded-For", the last item of the + // corresponding header field (assumed to be comma-separated) is used. If the value is + // "RemoteAddr", the connecting host's address is used. Any other value is disallowed. + IncludeIPs string `toml:"include-ip"` // Endpoint to notify with a `GET /?` whenever an audit event occurs. NotifyURL *URL `toml:"notify-url"` } diff --git a/src/http.go b/src/http.go index 2effa15..6095fec 100644 --- a/src/http.go +++ b/src/http.go @@ -2,6 +2,9 @@ package git_pages import ( "cmp" + "fmt" + "net" + "net/http" "regexp" "slices" "strconv" @@ -129,3 +132,46 @@ func (e *HTTPEncodings) Negotiate(offers ...string) string { } return preferredAcceptOffer(encs) } + +func chainHTTPMiddleware(middleware ...func(http.Handler) http.Handler) func(http.Handler) http.Handler { + return func(handler http.Handler) http.Handler { + for idx := len(middleware) - 1; idx >= 0; idx-- { + handler = middleware[idx](handler) + } + return handler + } +} + +func remoteAddrMiddleware(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var readXForwardedFor bool + switch config.Audit.IncludeIPs { + case "X-Forwarded-For": + readXForwardedFor = true + case "RemoteAddr", "": + readXForwardedFor = false + default: + panic(fmt.Errorf("config.Audit.IncludeIPs is set to an unknown value (%q)", + config.Audit.IncludeIPs)) + } + + usingOriginalRemoteAddr := true + if readXForwardedFor { + forwardedFor := strings.Split(r.Header.Get("X-Forwarded-For"), ",") + if len(forwardedFor) > 0 { + remoteAddr := strings.TrimSpace(forwardedFor[len(forwardedFor)-1]) + if remoteAddr != "" { + r.RemoteAddr = remoteAddr + usingOriginalRemoteAddr = false + } + } + } + if usingOriginalRemoteAddr { + if ipAddress, _, err := net.SplitHostPort(r.RemoteAddr); err == nil { + r.RemoteAddr = ipAddress + } + } + + handler.ServeHTTP(w, r) + }) +} diff --git a/src/main.go b/src/main.go index 96f4058..f5c1b11 100644 --- a/src/main.go +++ b/src/main.go @@ -132,8 +132,6 @@ func panicHandler(handler http.Handler) http.Handler { func serve(ctx context.Context, listener net.Listener, handler http.Handler) { if listener != nil { - handler = panicHandler(handler) - server := http.Server{Handler: handler} server.Protocols = new(http.Protocols) server.Protocols.SetHTTP1(true) @@ -537,8 +535,13 @@ func Main() { } backend = NewObservedBackend(backend) - go serve(ctx, pagesListener, ObserveHTTPHandler(http.HandlerFunc(ServePages))) - go serve(ctx, caddyListener, ObserveHTTPHandler(http.HandlerFunc(ServeCaddy))) + middleware := chainHTTPMiddleware( + panicHandler, + remoteAddrMiddleware, + ObserveHTTPHandler, + ) + go serve(ctx, pagesListener, middleware(http.HandlerFunc(ServePages))) + go serve(ctx, caddyListener, middleware(http.HandlerFunc(ServeCaddy))) go serve(ctx, metricsListener, promhttp.Handler()) if config.Insecure { diff --git a/src/pages.go b/src/pages.go index 7f65c89..4fcaea1 100644 --- a/src/pages.go +++ b/src/pages.go @@ -9,7 +9,6 @@ import ( "fmt" "io" "maps" - "net" "net/http" "net/url" "os" @@ -802,10 +801,8 @@ func postPage(w http.ResponseWriter, r *http.Request) error { func ServePages(w http.ResponseWriter, r *http.Request) { r = r.WithContext(WithPrincipal(r.Context())) - if config.Audit.IncludeIPs { - if ipAddress, _, err := net.SplitHostPort(r.RemoteAddr); err == nil { - GetPrincipal(r.Context()).IpAddress = proto.String(ipAddress) - } + if config.Audit.IncludeIPs != "" { + GetPrincipal(r.Context()).IpAddress = proto.String(r.RemoteAddr) } // We want upstream health checks to be done as closely to the normal flow as possible; // any intentional deviation is an opportunity to miss an issue that will affect our