diff --git a/cmd/domain-router/main.go b/cmd/domain-router/main.go index 5ef44f5..5f1b96c 100644 --- a/cmd/domain-router/main.go +++ b/cmd/domain-router/main.go @@ -40,7 +40,7 @@ func main() { router := domainrouter.New(config, client) mux := http.NewServeMux() - mux.HandleFunc("/", router.Route) + mux.HandleFunc("/", router.ServeHTTP) if config.General.AnnouncePublic { h, err := url.JoinPath("/", config.General.HealthEndpoint) diff --git a/config.yaml b/config.yaml index 4b0c778..c4fe7f1 100644 --- a/config.yaml +++ b/config.yaml @@ -5,7 +5,7 @@ server: certFile: server.crt keyFile: server.key acme: - enabled: true + enabled: false email: me@pablu.de keyFile: userKey.key caDirUrl: https://192.168.2.154:14000/dir @@ -29,7 +29,7 @@ logging: rateLimit: - enabled: true + enabled: false # How many requests per ip adress are allowed bucketSize: 50 # How many requests per ip address are refilled @@ -44,7 +44,6 @@ hosts: # Remote address to request - remotes: - localhost - - 192.168.2.154 # Port on which to request port: 8181 # Health check if announce is true @@ -61,6 +60,14 @@ hosts: domains: - private.localhost + - remotes: + - localhost + port: 5173 + public: false + domains: + - hitstar.localhost + - hipstar.localhost + # - remotes: # - www.google.com # - localhost diff --git a/router.go b/router.go index 0a5d21a..525d44b 100644 --- a/router.go +++ b/router.go @@ -1,12 +1,16 @@ package domainrouter import ( + "context" "encoding/json" "errors" "fmt" "io" "net/http" "net/http/httputil" + "net/url" + "slices" + "strings" "sync/atomic" "github.com/rs/zerolog/log" @@ -103,87 +107,30 @@ func (router *Router) Healthz(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } -func createRequest(r *http.Request, host *Host, remote string) (*http.Request, error) { +func rewriteRequestURL(r *http.Request, host *Host, remote string) error { subUrlPath := r.URL.RequestURI() - var url string + var uri string if host.Secure { - url = fmt.Sprintf("https://%s:%d%s", remote, host.Port, subUrlPath) + uri = fmt.Sprintf("https://%s:%d%s", remote, host.Port, subUrlPath) } else { - url = fmt.Sprintf("http://%s:%d%s", remote, host.Port, subUrlPath) + uri = fmt.Sprintf("http://%s:%d%s", remote, host.Port, subUrlPath) } - req, err := http.NewRequest(r.Method, url, r.Body) - if err != nil { - return nil, err - } - - copyRequestHeader(r, req) - req.Header.Set("X-Forwarded-For", r.RemoteAddr) - req.Header.Set("Cache-Control", "no-store, no-cache, max-age=0, must-revalidate, proxy-revalidate") - - for _, cookie := range r.Cookies() { - req.AddCookie(cookie) - } - - return req, nil -} - -var etagHeaders = map[string]struct{}{ - "ETag": {}, - "If-Modified-Since": {}, - "If-Match": {}, - "If-None-Match": {}, - "If-Range": {}, - "If-Unmodified-Since": {}, -} - -func copyRequestHeader(origin *http.Request, destination *http.Request) { - for name, values := range origin.Header { - - // Skip etag Headers - if _, ok := etagHeaders[name]; ok { - continue - } - - for _, value := range values { - destination.Header.Set(name, value) - } - } -} - -func applyResponseHeader(w http.ResponseWriter, res *http.Response) { - for name, values := range res.Header { - for _, value := range values { - w.Header().Set(name, value) - } - } - w.WriteHeader(res.StatusCode) -} - -func applyCookies(w http.ResponseWriter, res *http.Response) { - cookies := res.Cookies() - for _, cookie := range cookies { - http.SetCookie(w, cookie) - } -} - -func applyBody(w http.ResponseWriter, res *http.Response) error { - body, err := io.ReadAll(res.Body) - defer res.Body.Close() + remoteUrl, err := url.Parse(uri) if err != nil { return err } - _, err = w.Write(body) - if err != nil { - return err - } + r.RequestURI = "" + r.URL.Scheme = remoteUrl.Scheme + r.URL.Host = remoteUrl.Host + r.Header.Set("X-Forwarded-For", r.RemoteAddr) + r.Header.Set("Cache-Control", "no-store, no-cache, max-age=0, must-revalidate, proxy-revalidate") - w.WriteHeader(res.StatusCode) return nil } -func (router *Router) Route(w http.ResponseWriter, r *http.Request) { +func (router *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { // If trace enabled dump incoming request, could break request so exit early if that happens if !dumpRequest(w, r) { return @@ -199,7 +146,41 @@ func (router *Router) Route(w http.ResponseWriter, r *http.Request) { remote := host.Remotes[host.Current.Load()] go router.roundRobin(&host) - req, err := createRequest(r, &host, remote) + // Copy request + // Copy body with Buffer Pool + ctx := r.Context() + + outreq := r.Clone(ctx) + if r.ContentLength == 0 { + outreq.Body = nil + } + if outreq.Body != nil { + defer outreq.Body.Close() + } + + reqUpType := upgradeType(outreq.Header) + if !isPrintableAscii(reqUpType) { + log.Error().Str("request_upgrade_type", reqUpType).Msg("Client tried to switch to invalid protocol") + return + } + removeHopByHopHeaders(outreq.Header) + + if slices.Contains(r.Header.Values("Te"), "trailers") { + outreq.Header.Set("Te", "trailers") + } + + if reqUpType != "" { + outreq.Header.Set("Connection", "Upgrade") + outreq.Header.Set("Upgrade", reqUpType) + } + + stripClientProvidedXForwardHeaders(outreq.Header) + + if _, ok := outreq.Header["User-Agent"]; !ok { + outreq.Header.Set("User-Agent", "") + } + + err := rewriteRequestURL(outreq, &host, remote) if err != nil { log.Error().Err(err).Bool("secure", host.Secure).Str("remote", remote).Int("port", host.Port).Msg("Could not create request") w.WriteHeader(http.StatusInternalServerError) @@ -207,11 +188,11 @@ func (router *Router) Route(w http.ResponseWriter, r *http.Request) { } // Dump created request - if !dumpRequest(w, req) { + if !dumpRequest(w, outreq) { return } - res, err := router.client.Do(req) + res, err := router.client.Do(outreq) if err != nil { log.Error().Err(err).Str("remote", remote).Int("port", host.Port).Msg("Could not complete request") w.WriteHeader(http.StatusInternalServerError) @@ -223,35 +204,180 @@ func (router *Router) Route(w http.ResponseWriter, r *http.Request) { return } - applyCookies(w, res) - applyResponseHeader(w, res) - - // Exit early because its a redirect - // Maybe this should be before applying cookies or after applying headers - if !handleLocation(w, r, res) { + if res.StatusCode == http.StatusSwitchingProtocols { + router.handleUpgradeResponse(w, res, outreq) return } - err = applyBody(w, res) - if err != nil { - log.Error().Err(err).Msg("Could not apply body") - w.WriteHeader(http.StatusInternalServerError) - return + removeHopByHopHeaders(res.Header) + + copyHeader(w.Header(), res.Header) + + w.WriteHeader(res.StatusCode) + err = router.copyResponse(w, res.Body) + res.Body.Close() +} + +func (router *Router) copyResponse(dst http.ResponseWriter, src io.ReadCloser) error { + buf := make([]byte, 32*1024) + var written int64 + for { + nr, rerr := src.Read(buf) + if rerr != nil && rerr != io.EOF && rerr != context.Canceled { + log.Error().Err(rerr).Msg("Could not copy body") + return rerr + } + if nr > 0 { + nw, werr := dst.Write(buf[:nr]) + if nw > 0 { + written += int64(nw) + } + if werr != nil { + return werr + } + } + if rerr != nil && rerr == io.EOF { + return nil + } } } -func handleLocation(w http.ResponseWriter, r *http.Request, res *http.Response) bool { - if loc, err := res.Location(); err == nil { - http.Redirect(w, r, loc.String(), http.StatusFound) - return false - } else if !errors.Is(err, http.ErrNoLocation) { - log.Error().Err(err).Msg("Could not extract location") - w.WriteHeader(http.StatusInternalServerError) - return false +func copyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } +} + +func (router *Router) handleUpgradeResponse(w http.ResponseWriter, res *http.Response, req *http.Request) { + reqUpType := upgradeType(req.Header) + resUpType := upgradeType(res.Header) + if !strings.EqualFold(reqUpType, resUpType) { + log.Error().Str("response_upgrade_type", resUpType).Str("request_upgrade_type", reqUpType).Msg("Response and Request Upgrade type do not match") + return + } + + backConn, ok := res.Body.(io.ReadWriteCloser) + if !ok { + log.Error().Msg("Could not switch protocols with non writable body") + return + } + + rc := http.NewResponseController(w) + conn, brw, hijackErr := rc.Hijack() + if errors.Is(hijackErr, http.ErrNotSupported) { + log.Error().Type("response_writer_type", w).Msg("Could not switch protocols using non-Hijacker ResponseWriter") + return + } + + backConnCloseCh := make(chan bool) + go func() { + select { + case <-req.Context().Done(): + case <-backConnCloseCh: + } + backConn.Close() + }() + defer close(backConnCloseCh) + + if hijackErr != nil { + log.Error().Err(hijackErr).Msg("Hijack failed on protocol switch") + return + } + defer conn.Close() + + copyHeader(w.Header(), res.Header) + + res.Header = w.Header() + res.Body = nil + if err := res.Write(brw); err != nil { + log.Error().Err(err).Msg("Could not write") + return + } + if err := brw.Flush(); err != nil { + log.Error().Err(err).Msg("Could not flush") + return + } + errc := make(chan error, 1) + spc := switchProtocolCopier{user: conn, backend: backConn} + go spc.copyToBackend(errc) + go spc.copyFromBackend(errc) + + err := <-errc + if err == nil { + err = <-errc + } +} + +type switchProtocolCopier struct { + user, backend io.ReadWriter +} + +var errCopyDone = errors.New("hijacked connection copy complete") + +func (c switchProtocolCopier) copyFromBackend(errc chan<- error) { + if _, err := io.Copy(c.user, c.backend); err != nil { + errc <- err + return + + } + + // backend conn has reached EOF so propogate close write to user conn + if wc, ok := c.user.(interface{ CloseWrite() error }); ok { + errc <- wc.CloseWrite() + return + } + errc <- errCopyDone +} + +func (c switchProtocolCopier) copyToBackend(errc chan<- error) { + if _, err := io.Copy(c.backend, c.user); err != nil { + errc <- err + return + } + // user conn has reached EOF so propogate close write to backend conn + if wc, ok := c.backend.(interface{ CloseWrite() error }); ok { + errc <- wc.CloseWrite() + return + } + + errc <- errCopyDone +} + +func stripClientProvidedXForwardHeaders(header http.Header) { + header.Del("Forwarded") + header.Del("X-Forwarded-For") + header.Del("X-Forwarded-Host") + header.Del("X-Forwarded-Proto") +} + +func removeHopByHopHeaders(header http.Header) { + for _, f := range header.Values("Connection") { + if strings.TrimSpace(f) != "" { + header.Del(f) + } + } +} + +func isPrintableAscii(reqUpType string) bool { + for _, c := range reqUpType { + if c < 32 && c > 126 { + return false + } } return true } +func upgradeType(header http.Header) string { + for _, val := range header.Values("Connection") { + if strings.ToLower(val) == "upgrade" { + return header.Get("Upgrade") + } + } + return "" +} + func dumpRequest(w http.ResponseWriter, r *http.Request) bool { if e := log.Trace(); e.Enabled() && r.Method == "POST" { rDump, err := httputil.DumpRequest(r, true)