diff --git a/config.yaml b/config.yaml index b9a875a..560fb58 100644 --- a/config.yaml +++ b/config.yaml @@ -15,7 +15,7 @@ server: logging: - level: debug + level: trace # Pretty print for human consumption otherwise json pretty: true # Log incoming requests @@ -70,6 +70,12 @@ hosts: domains: - chat.localhost + - remotes: + - localhost + port: 8080 + domains: + - gorilla.localhost + - remotes: - www.google.com port: 443 diff --git a/go.mod b/go.mod index ebaf4ff..2dc041b 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( require ( github.com/go-acme/lego v2.7.2+incompatible github.com/go-acme/lego/v4 v4.24.0 + github.com/google/uuid v1.6.0 github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/urfave/negroni v1.0.0 diff --git a/go.sum b/go.sum index 862742e..b879d78 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/go-acme/lego/v4 v4.24.0/go.mod h1:hkstZY6D0jylIrZbuNmEQrWQxTIfaJH7prw github.com/go-jose/go-jose/v4 v4.0.5 h1:M6T8+mKZl/+fNNuFHvGIzDz7BTLQPIounk/b9dw3AaE= github.com/go-jose/go-jose/v4 v4.0.5/go.mod h1:s3P1lRrkT8igV8D9OjyL4WRyHvjB6a4JSllnOrmmBOA= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= diff --git a/middleware/logging.go b/middleware/logging.go index 1d09918..8ead220 100644 --- a/middleware/logging.go +++ b/middleware/logging.go @@ -4,23 +4,39 @@ import ( "net/http" "time" + "github.com/google/uuid" "github.com/rs/zerolog/log" "github.com/urfave/negroni" ) func RequestLogger(next http.Handler) http.Handler { - log.Info().Msg("Enabling Logging") + log.Info().Msg("Enabling Logging") return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() lrw := negroni.NewResponseWriter(w) + + uuid := uuid.New().String() + log.Info(). + Str("host", r.Host). + Str("uri", r.RequestURI). + Str("method", r.Method). + Str("uuid", uuid). + Msg("Received Request") + next.ServeHTTP(lrw, r) duration := time.Since(start) if duration.Milliseconds() > 500 { - log.Warn().Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Int("status", lrw.Status()).Int("size", lrw.Size()).Str("duration", duration.String()).Msg("Slow Request") - } else { - log.Info().Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Int("status", lrw.Status()).Int("size", lrw.Size()).Str("duration", duration.String()).Msg("Received Request") + log.Warn(). + Str("host", r.Host). + Str("uri", r.RequestURI). + Str("method", r.Method). + Int("status", lrw.Status()). + Int("size", lrw.Size()). + Str("duration", duration.String()). + Str("uuid", uuid). + Msg("Slow Request") } }) } diff --git a/router.go b/router.go index daeb3d5..8e156bb 100644 --- a/router.go +++ b/router.go @@ -6,10 +6,12 @@ import ( "fmt" "io" "net/http" - "net/http/httputil" + "net/http/httptrace" + "net/textproto" "net/url" "slices" "strings" + "sync" "sync/atomic" "github.com/rs/zerolog/log" @@ -52,7 +54,7 @@ func (router *Router) roundRobin(host *Host) { } } -func rewriteRequestURL(r *http.Request, host *Host, remote string, ) error { +func rewriteRequestURL(r *http.Request, host *Host, remote string) error { subUrlPath := r.URL.RequestURI() var uri string if host.Secure { @@ -76,14 +78,11 @@ func rewriteRequestURL(r *http.Request, host *Host, remote string, ) error { } func (router *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // If trace enabled dump incoming request, could break request so exit early if that happens - if !dumpRequest(w, r) { - return - } - - host, ok := router.domains.Get(r.Host) + transport := http.DefaultTransport + portLessHost, _ := strings.CutSuffix(r.Host, fmt.Sprintf(":%d", router.config.Server.Port)) + host, ok := router.domains.Get(portLessHost) if !ok { - log.Warn().Str("host", r.Host).Msg("Could not find Host") + log.Warn().Str("host", portLessHost).Msg("Could not find Host") w.WriteHeader(http.StatusOK) return } @@ -102,6 +101,7 @@ func (router *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { if outreq.Body != nil { defer outreq.Body.Close() } + outreq.Close = false reqUpType := upgradeType(outreq.Header) if !isPrintableAscii(reqUpType) { @@ -117,7 +117,7 @@ func (router *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { if reqUpType != "" { outreq.Header.Set("Connection", "Upgrade") outreq.Header.Set("Upgrade", reqUpType) - log.Debug().Str("upgrade", reqUpType).Msg("Request upgrade") + log.Trace().Str("upgrade_type", reqUpType).Msg("Found upgrade Type") } stripClientProvidedXForwardHeaders(outreq.Header) @@ -133,20 +133,34 @@ func (router *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - // Dump created request - if !dumpRequest(w, outreq) { - return - } + var ( + roundTripMutex sync.Mutex + roundTripDone bool + ) - res, err := router.client.Do(outreq) + trace := &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + roundTripMutex.Lock() + defer roundTripMutex.Unlock() + if roundTripDone { + return nil + } + h := w.Header() + copyHeader(h, http.Header(header)) + w.WriteHeader(code) + + clear(h) + return nil + }, + } + outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace)) + + res, err := transport.RoundTrip(outreq) + roundTripMutex.Lock() + roundTripDone = true + roundTripMutex.Unlock() if err != nil { - log.Error().Err(err).Str("remote", remote).Int("port", host.Port).Msg("Could not complete request") - w.WriteHeader(http.StatusInternalServerError) - return - } - - // If trace enabled dump response - if !dumpResponse(w, res) { + log.Error().Err(err).Any("out_request", outreq).Msg("Could not complete transport round trip") return } @@ -266,7 +280,6 @@ func (c switchProtocolCopier) copyFromBackend(errc chan<- error) { if _, err := io.Copy(c.user, c.backend); err != nil { errc <- err return - } // backend conn has reached EOF so propogate close write to user conn @@ -316,36 +329,15 @@ func isPrintableAscii(reqUpType string) bool { } func upgradeType(header http.Header) string { - for _, val := range header.Values("Connection") { - if strings.ToLower(val) == "upgrade" { - return header.Get("Upgrade") + // Iterate over Connection headers if those exist multiple times + for _, conVal := range header.Values("Connection") { + for _, headerVal := range strings.Split(conVal, ",") { + trimmed := strings.TrimSpace(headerVal) + if strings.EqualFold(trimmed, "upgrade") { + upType := header.Get("Upgrade") + return upType + } } } return "" } - -func dumpRequest(w http.ResponseWriter, r *http.Request) bool { - if e := log.Trace(); e.Enabled() { - rDump, err := httputil.DumpRequest(r, true) - if err != nil { - log.Error().Err(err).Msg("Could not dump request") - w.WriteHeader(http.StatusInternalServerError) - return false - } - log.Trace().Str("dump", string(rDump)).Msg("Dumping Request") - } - return true -} - -func dumpResponse(w http.ResponseWriter, r *http.Response) bool { - if e := log.Trace(); e.Enabled() { - dump, err := httputil.DumpResponse(r, true) - if err != nil { - log.Error().Err(err).Msg("Could not dump response") - w.WriteHeader(http.StatusInternalServerError) - return false - } - log.Trace().Str("dump", string(dump)).Msg("Dumping Response") - } - return true -}