Add a domain cache to quickly reject non-existent domains.

This commit is contained in:
miyuko
2026-04-11 12:00:20 +00:00
parent f400f8d246
commit bbdaae7280
12 changed files with 225 additions and 5 deletions

View File

@@ -138,13 +138,16 @@ type Backend interface {
// Create a domain. This allows us to start serving content for the domain.
CreateDomain(ctx context.Context, domain string) error
// Freeze a domain. This allows a site to be administratively locked, e.g. if it
// Freeze a domain. This allows a site to be administratively locked, e.g. if it
// is discovered serving abusive content.
FreezeDomain(ctx context.Context, domain string) error
// Thaw a domain. This removes the previously placed administrative lock (if any).
UnfreezeDomain(ctx context.Context, domain string) error
// Check whether the set of domains we serve has changed since the time passed to this method.
HaveDomainsChanged(ctx context.Context, since time.Time) (changed bool, err error)
// Append a record to the audit log.
AppendAuditLog(ctx context.Context, id AuditID, record *AuditRecord) error

View File

@@ -11,6 +11,7 @@ import (
"os"
"path/filepath"
"strings"
"time"
)
type FSBackend struct {
@@ -479,6 +480,10 @@ func (fs *FSBackend) UnfreezeDomain(ctx context.Context, domain string) error {
}
}
// HaveDomainsChanged always reports that the set of domains has changed.
//
// Change tracking is not implemented for the filesystem backend, so `since` is ignored and
// callers must assume that their view of the domains is stale.
func (fs *FSBackend) HaveDomainsChanged(ctx context.Context, since time.Time) (bool, error) {
	const alwaysChanged = true // not implemented
	return alwaysChanged, nil
}
func (fs *FSBackend) AppendAuditLog(ctx context.Context, id AuditID, record *AuditRecord) error {
if _, err := fs.auditRoot.Stat(id.String()); err == nil {
panic(fmt.Errorf("audit ID collision: %s", id))

View File

@@ -643,8 +643,11 @@ func (s3 *S3Backend) DeleteManifest(
err := s3.client.RemoveObject(ctx, s3.bucket, manifestObjectName(name),
minio.RemoveObjectOptions{})
if err != nil {
return err
}
s3.siteCache.Cache.Invalidate(name)
return err
return s3.bumpLastDomainUpdateTimestamp(ctx)
}
func (s3 *S3Backend) EnumerateManifests(ctx context.Context) iter.Seq2[*ManifestMetadata, error] {
@@ -764,8 +767,19 @@ func (s3 *S3Backend) CheckDomain(ctx context.Context, domain string) (exists boo
func (s3 *S3Backend) CreateDomain(ctx context.Context, domain string) error {
logc.Printf(ctx, "s3: create domain %s\n", domain)
_, err := s3.client.PutObject(ctx, s3.bucket, domainCheckObjectName(domain),
exists, err := s3.CheckDomain(ctx, domain)
if err != nil {
return err
}
_, err = s3.client.PutObject(ctx, s3.bucket, domainCheckObjectName(domain),
&bytes.Reader{}, 0, minio.PutObjectOptions{})
if err != nil {
return err
}
if !exists {
err = s3.bumpLastDomainUpdateTimestamp(ctx)
}
return err
}
@@ -790,6 +804,25 @@ func (s3 *S3Backend) UnfreezeDomain(ctx context.Context, domain string) error {
}
}
// Name of the empty marker object whose modification time records the most recent change to
// the set of domains (see bumpLastDomainUpdateTimestamp).
const lastDomainUpdateObjectName = "meta/last-domain-update"

// HaveDomainsChanged reports whether the marker object was last modified after `since`.
//
// If the marker object does not exist yet (nothing has bumped it, e.g. in a bucket that
// predates the marker), it is unknown whether the domains changed. In that case a change is
// reported rather than an error, so that callers rebuild their view instead of failing.
//
// NOTE(review): `since` comes from the caller's clock while LastModified comes from the
// object store; clock skew or coarse timestamp resolution could hide a change. Confirm that
// this is acceptable.
func (s3 *S3Backend) HaveDomainsChanged(ctx context.Context, since time.Time) (bool, error) {
	info, err := s3.client.StatObject(ctx, s3.bucket, lastDomainUpdateObjectName,
		minio.GetObjectOptions{})
	if err != nil {
		if minio.ToErrorResponse(err).Code == "NoSuchKey" {
			// No marker: be conservative and report a change.
			return true, nil
		}
		return false, err
	}
	return info.LastModified.After(since), nil
}
// bumpLastDomainUpdateTimestamp rewrites the (empty) marker object named by
// lastDomainUpdateObjectName. HaveDomainsChanged compares that object's modification time
// against the caller's timestamp.
func (s3 *S3Backend) bumpLastDomainUpdateTimestamp(ctx context.Context) error {
	logc.Print(ctx, "s3: bumping last domain update timestamp")

	// The content is irrelevant; only the write itself matters.
	if _, putErr := s3.client.PutObject(ctx, s3.bucket, lastDomainUpdateObjectName,
		&bytes.Reader{}, 0, minio.PutObjectOptions{}); putErr != nil {
		return putErr
	}
	return nil
}
// auditObjectName returns the object name for the audit record with the given ID, which is
// the ID formatted with `%s` under the "audit/" prefix.
func auditObjectName(id AuditID) string {
	objectName := fmt.Sprintf("audit/%s", id)
	return objectName
}

View File

@@ -26,7 +26,17 @@ func ServeCaddy(w http.ResponseWriter, r *http.Request) {
return
}
found, err := backend.CheckDomain(r.Context(), strings.ToLower(domain))
var err error
domain = strings.ToLower(domain)
// Run a cheap check as to whether we might be serving the domain.
var found = domainCache.CheckDomain(r.Context(), domain)
if !found {
// Run an expensive check as to whether we are actually serving the domain.
found, err = backend.CheckDomain(r.Context(), domain)
}
if !found {
// If we don't serve the domain, but a fallback server does, then we should let our
// Caddy instance request a TLS certificate. Otherwise, we'll never have an opportunity

132
src/domain_cache.go Normal file
View File

@@ -0,0 +1,132 @@
package git_pages
import (
"context"
"fmt"
"strings"
"sync"
"time"
"github.com/bits-and-blooms/bloom/v3"
)
// DomainCache is a cheap pre-check for whether a domain might be served, consulted before
// the more expensive backend lookups.
//
// A `false` answer is used by callers to reject or bypass a domain early. A `true` answer
// only means "maybe": the Bloom filter implementation has a non-zero false positive rate,
// and the dummy implementation answers `true` for every domain.
type DomainCache interface {
// CheckDomain reports whether we might be serving the domain.
CheckDomain(ctx context.Context, domain string) (found bool)
// AddDomain adds the domain to the cache, so that subsequent CheckDomain calls for it
// report `true`.
AddDomain(ctx context.Context, domain string)
}
// CreateDomainCache builds the domain cache selected by the configuration.
//
// With the "domain-existence-cache" feature enabled this is the Bloom filter cache, whose
// construction may fail; otherwise it is a dummy cache that reports every domain as possibly
// served.
func CreateDomainCache(ctx context.Context) (DomainCache, error) {
	if config.Feature("domain-existence-cache") {
		return createBloomDomainCache(ctx)
	}
	return &dummyDomainCache{}, nil
}
// bloomDomainCache is a DomainCache backed by a Bloom filter that is rebuilt from the
// backend's manifests by `refresh`.
type bloomDomainCache struct {
// Current filter; replaced wholesale by `refresh`. Accessed only with `filterMu` held.
filter *bloom.BloomFilter
filterMu sync.Mutex
// Receives a (non-blocking) signal on each CheckDomain call; drained by
// `handleFilterUpdates`, which decides whether a refresh is due.
accessCh chan struct{}
// Held for the duration of `refresh` and of `AddDomain`, so that the two never interleave.
refreshMu sync.Mutex
// Time recorded by the most recent successful `refresh`.
lastRefresh time.Time
// Age of `lastRefresh` beyond which an access triggers another refresh. Zero unless set
// from the S3 site cache configuration.
maxAge time.Duration
}
// createBloomDomainCache constructs the Bloom filter domain cache, populates it once
// synchronously, and starts the goroutine that refreshes it in response to accesses.
//
// Returns an error if the initial population fails. Panics on an unknown storage type.
func createBloomDomainCache(ctx context.Context) (DomainCache, error) {
	cache := &bloomDomainCache{accessCh: make(chan struct{})}

	switch kind := config.Storage.Type; kind {
	case "s3":
		// Refresh no more often than the S3 site cache expires.
		cache.maxAge = time.Duration(config.Storage.S3.SiteCache.MaxAge)
	case "fs":
		// the FS backend has no cache; maxAge stays zero, so any access may trigger a refresh
	default:
		panic(fmt.Errorf("unknown backend: %s", kind))
	}

	// Populate the filter before it can be queried.
	if err := cache.refresh(ctx); err != nil {
		return nil, err
	}

	go cache.handleFilterUpdates(ctx)
	return cache, nil
}
// handleFilterUpdates runs in its own goroutine and refreshes the filter when it is accessed
// after having become older than `maxAge`.
//
// Each value received on `accessCh` represents an access. Refresh errors are logged and
// otherwise ignored, leaving the previous filter in place. The loop ends when `ctx` is
// cancelled or `accessCh` is closed, so that the goroutine does not outlive its context.
func (c *bloomDomainCache) handleFilterUpdates(ctx context.Context) {
	for {
		select {
		case <-ctx.Done():
			return
		case _, ok := <-c.accessCh:
			if !ok {
				return
			}
		}

		if time.Since(c.lastRefresh) <= c.maxAge {
			continue // still fresh enough
		}

		logc.Print(ctx, "domain cache: refreshing")
		if err := c.refresh(ctx); err != nil {
			logc.Printf(ctx, "domain cache: refresh error: %v", err)
		}
	}
}
// refresh rebuilds the Bloom filter from the backend's manifests if the backend reports that
// the set of domains has changed since the last refresh.
//
// It holds `refreshMu` throughout, so it never interleaves with AddDomain. On error the
// existing filter and `lastRefresh` are left untouched.
func (c *bloomDomainCache) refresh(ctx context.Context) error {
	const (
		// bloom.New takes the filter size in bits: 256 KiB fits ~150K entries with a 0.1%
		// false positive rate when using 10 hash functions.
		filterBits   = 256 * 1024 * 8
		filterHashes = 10
	)

	c.refreshMu.Lock()
	defer c.refreshMu.Unlock()

	// Record the time before asking the backend, not after enumerating: a change that lands
	// while we enumerate must still be newer than `lastRefresh`, or the next refresh would
	// report "unchanged" and miss it.
	startedAt := time.Now()

	changed, err := backend.HaveDomainsChanged(ctx, c.lastRefresh)
	if err != nil {
		return err
	}
	if !changed {
		logc.Print(ctx, "domain cache: unchanged")
		c.lastRefresh = startedAt
		return nil
	}

	filter := bloom.New(filterBits, filterHashes)
	for metadata, err := range backend.EnumerateManifests(ctx) {
		if err != nil {
			return fmt.Errorf("enum manifests: %w", err)
		}
		// The domain is the part of the manifest name before the first slash.
		domain, _, _ := strings.Cut(metadata.Name, "/")
		filter.AddString(domain)
	}

	// Swap in the fully built filter, so readers never observe a partially filled one.
	c.filterMu.Lock()
	c.filter = filter
	c.filterMu.Unlock()

	logc.Printf(ctx, "domain cache: refreshed with approx. %d domains", filter.ApproximatedSize())
	c.lastRefresh = startedAt
	return nil
}
// CheckDomain reports whether the Bloom filter contains `domain`, and logs the outcome.
//
// Every call also signals the refresh goroutine without blocking: if that goroutine is not
// ready to receive, the signal is simply dropped.
func (c *bloomDomainCache) CheckDomain(ctx context.Context, domain string) (found bool) {
	// Non-blocking notification of an access.
	select {
	case c.accessCh <- struct{}{}:
	default:
	}

	c.filterMu.Lock()
	hit := c.filter.TestString(domain)
	c.filterMu.Unlock()

	logc.Printf(ctx, "domain cache: bloom filter returns %v for %q", hit, domain)
	return hit
}
// AddDomain inserts `domain` into the current Bloom filter and logs the addition.
//
// `refreshMu` is held for the whole call, so the insertion cannot interleave with a refresh
// that is in the middle of replacing the filter.
func (c *bloomDomainCache) AddDomain(ctx context.Context, domain string) {
	c.refreshMu.Lock()
	defer c.refreshMu.Unlock()

	{
		c.filterMu.Lock()
		c.filter.AddString(domain)
		c.filterMu.Unlock()
	}

	logc.Printf(ctx, "domain cache: added %q", domain)
}
// dummyDomainCache is a DomainCache that keeps no state: it reports every domain as possibly
// served, which leaves the decision to the callers' other checks.
type dummyDomainCache struct{}

// CheckDomain always reports `true`.
func (dummyDomainCache) CheckDomain(context.Context, string) bool {
	return true
}

// AddDomain does nothing.
func (dummyDomainCache) AddDomain(context.Context, string) {
}

View File

@@ -33,6 +33,7 @@ var config *Config
var wildcards []*WildcardPattern
var fallback http.Handler
var backend Backend
var domainCache DomainCache
func configureFeatures(ctx context.Context) (err error) {
if len(config.Features) > 0 {
@@ -639,6 +640,10 @@ func Main(versionInfo string) {
}
backend = NewObservedBackend(backend)
if domainCache, err = CreateDomainCache(ctx); err != nil {
logc.Fatalln(ctx, err)
}
middleware := chainHTTPMiddleware(
panicHandler,
remoteAddrMiddleware,

View File

@@ -346,6 +346,13 @@ func (backend *observedBackend) UnfreezeDomain(ctx context.Context, domain strin
return
}
// HaveDomainsChanged forwards to the wrapped backend inside an observation span named
// "HaveDomainsChanged" that is annotated with the `since` argument.
func (backend *observedBackend) HaveDomainsChanged(ctx context.Context, since time.Time) (changed bool, err error) {
	span, spanCtx := ObserveFunction(ctx, "HaveDomainsChanged", "since", since)
	changed, err = backend.inner.HaveDomainsChanged(spanCtx, since)
	span.Finish()
	return changed, err
}
func (backend *observedBackend) AppendAuditLog(ctx context.Context, id AuditID, record *AuditRecord) (err error) {
span, ctx := ObserveFunction(ctx, "AppendAuditLog", "audit.id", id)
err = backend.inner.AppendAuditLog(ctx, id, record)

View File

@@ -65,8 +65,12 @@ func observeSiteUpdate(via string, result *UpdateResult) {
}
}
// normalizeHost returns the canonical form of a host name, which is its lowercase form.
func normalizeHost(host string) string {
	lowered := strings.ToLower(host)
	return lowered
}
func makeWebRoot(host string, projectName string) string {
return path.Join(strings.ToLower(host), projectName)
return path.Join(normalizeHost(host), projectName)
}
func getWebRoot(r *http.Request) (string, error) {
@@ -115,6 +119,13 @@ func getPage(w http.ResponseWriter, r *http.Request) error {
return err
}
host = normalizeHost(host)
if !domainCache.CheckDomain(r.Context(), host) {
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, "site not found\n")
return nil
}
type indexManifestResult struct {
manifest *Manifest
metadata ManifestMetadata

View File

@@ -59,6 +59,7 @@ func Update(
if err == nil {
domain, _, _ := strings.Cut(webRoot, "/")
err = backend.CreateDomain(ctx, domain)
domainCache.AddDomain(ctx, domain)
}
if err == nil {
if oldManifest == nil {