diff --git a/src/caddy.go b/src/caddy.go index 74761c5..9f39544 100644 --- a/src/caddy.go +++ b/src/caddy.go @@ -6,7 +6,6 @@ import ( "fmt" "net" "net/http" - "net/url" "strings" ) @@ -55,14 +54,11 @@ func ServeCaddy(w http.ResponseWriter, r *http.Request) { } func tryDialWithSNI(ctx context.Context, domain string) (bool, error) { - if config.Fallback.ProxyTo == "" { + if config.Fallback.ProxyTo == nil { return false, nil } - fallbackURL, err := url.Parse(config.Fallback.ProxyTo) - if err != nil { - return false, err - } + fallbackURL := config.Fallback.ProxyTo if fallbackURL.Scheme != "https" { return false, nil } diff --git a/src/config.go b/src/config.go index 72e0530..9303b7e 100644 --- a/src/config.go +++ b/src/config.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "net/url" "os" "reflect" "slices" @@ -16,7 +17,7 @@ import ( "github.com/pelletier/go-toml/v2" ) -// For some reason, the standard `time.Duration` type doesn't implement the standard +// For an unknown reason, the standard `time.Duration` type doesn't implement the standard // `encoding.{TextMarshaler,TextUnmarshaler}` interfaces. type Duration time.Duration @@ -26,7 +27,9 @@ func (t Duration) String() string { func (t *Duration) UnmarshalText(data []byte) (err error) { u, err := time.ParseDuration(string(data)) - *t = Duration(u) + if err == nil { + *t = Duration(u) + } return } @@ -34,6 +37,28 @@ func (t *Duration) MarshalText() ([]byte, error) { return []byte(t.String()), nil } +// For a known but upsetting reason, the standard `url.URL` type doesn't implement the standard +// `encoding.{TextMarshaler,TextUnmarshaler}` interfaces. +type URL struct { + url.URL +} + +func (t *URL) String() string { + return fmt.Sprint(&t.URL) +} + +func (t *URL) UnmarshalText(data []byte) (err error) { + u, err := url.Parse(string(data)) + if err == nil { + *t = URL{*u} + } + return +} + +func (t *URL) MarshalText() ([]byte, error) { + return []byte(t.String()), nil +} + type Config struct { Insecure bool `toml:"-" env:"insecure"` Features []string `toml:"features"` @@ -56,15 +81,15 @@ type ServerConfig struct { type WildcardConfig struct { Domain string `toml:"domain"` - CloneURL string `toml:"clone-url"` + CloneURL string `toml:"clone-url"` // URL template, not an exact URL IndexRepos []string `toml:"index-repos" default:"[]"` IndexRepoBranch string `toml:"index-repo-branch" default:"pages"` Authorization string `toml:"authorization"` } type FallbackConfig struct { - ProxyTo string `toml:"proxy-to"` - Insecure bool `toml:"insecure"` + ProxyTo *URL `toml:"proxy-to"` + Insecure bool `toml:"insecure"` } type CacheConfig struct { @@ -223,15 +248,20 @@ func setConfigValue(reflValue reflect.Value, repr string) (err error) { if valueCast, err = datasize.ParseString(repr); err == nil { reflValue.Set(reflect.ValueOf(valueCast)) } - case time.Duration: - if valueCast, err = time.ParseDuration(repr); err == nil { - reflValue.Set(reflect.ValueOf(valueCast)) - } case Duration: var parsed time.Duration if parsed, err = time.ParseDuration(repr); err == nil { reflValue.Set(reflect.ValueOf(Duration(parsed))) } + case *URL: + if repr == "" { + reflValue.Set(reflect.ValueOf(nil)) + } else { + var parsed *url.URL + if parsed, err = url.Parse(repr); err == nil { + reflValue.Set(reflect.ValueOf(&URL{*parsed})) + } + } case []WildcardConfig: var parsed []*WildcardConfig decoder := json.NewDecoder(bytes.NewReader([]byte(repr))) diff --git a/src/main.go b/src/main.go index 21948af..6744184 100644 --- a/src/main.go +++ b/src/main.go @@ -67,14 +67,8 @@ func configureWildcards(_ context.Context) (err error) { } func configureFallback(_ context.Context) (err error) { - if config.Fallback.ProxyTo != "" { - var fallbackURL *url.URL - fallbackURL, err = url.Parse(config.Fallback.ProxyTo) - if err != nil { - err = fmt.Errorf("fallback: %w", err) - return - } - + if config.Fallback.ProxyTo != nil { + fallbackURL := &config.Fallback.ProxyTo.URL fallback = &httputil.ReverseProxy{ Rewrite: func(r *httputil.ProxyRequest) { r.SetURL(fallbackURL)