diff --git a/config.go b/config.go index e65ce27..34b4c3c 100644 --- a/config.go +++ b/config.go @@ -19,11 +19,12 @@ type Config struct { } `yaml:"ssl"` } `yaml:"server"` Hosts []struct { - Port int `yaml:"port"` - Remotes []string `yaml:"remotes"` - Domains []string `yaml:"domains"` - Secure bool `yaml:"secure"` - Rewrite map[string]string `yaml:"rewrite"` + Port int `yaml:"port"` + Remotes []string `yaml:"remotes"` + Domains []string `yaml:"domains"` + Secure bool `yaml:"secure"` + Rewrite map[string]string `yaml:"rewrite"` + AdditionalHeaders map[string]string `yaml:"extraHeaders"` } `yaml:"hosts"` RateLimit struct { Enabled bool `yaml:"enabled"` diff --git a/router.go b/router.go index 8dc09c7..6add31c 100644 --- a/router.go +++ b/router.go @@ -24,18 +24,19 @@ type Router struct { } type Host struct { - Port int - Remotes []string - Secure bool - Current *atomic.Uint32 - Rewrites map[string]*Host + Port int + Remotes []string + Secure bool + Current *atomic.Uint32 + Rewrites map[string]*Host + AdditionalHeaders map[string]string } func New(config *Config, client *http.Client) Router { m := make(map[string]Host) for _, host := range config.Hosts { for _, domain := range host.Domains { - curr := Host{host.Port, host.Remotes, host.Secure, &atomic.Uint32{}, make(map[string]*Host)} + curr := Host{host.Port, host.Remotes, host.Secure, &atomic.Uint32{}, make(map[string]*Host), host.AdditionalHeaders} m[domain] = curr for subUrl, rewriteHost := range host.Rewrite { @@ -98,17 +99,6 @@ func (router *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { } for subUrl, rewriteHost := range host.Rewrites { - // parts := strings.Split(subUrl, "/") - // requestParts := strings.Split(r.URL.Path, "/") - // - // for i, part := range parts { - // if !strings.EqualFold(part, requestParts[i]) { - // break - // } - // } - // - // slicedPath := "/" + strings.Join(requestParts[len(parts):], "/") - if !strings.HasPrefix(r.URL.Path, subUrl) { break } @@ -122,7 +112,7 @@ func (router *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { Str("requested_path", r.URL.Path). Str("new_path", slicedPath). Msg("Rewriting matched url path to different remote") - + r.URL.Path = slicedPath host = *rewriteHost break @@ -212,7 +202,11 @@ func (router *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { removeHopByHopHeaders(res.Header) - copyHeader(w.Header(), res.Header) + resultHeader := w.Header() + copyHeader(resultHeader, res.Header) + for name, val := range host.AdditionalHeaders { + resultHeader.Add(name, val) + } w.WriteHeader(res.StatusCode) err = router.copyResponse(w, res.Body)