Fix Tls, fix logging and add uuid for correlation of websocket requests

This commit is contained in:
Pablu23
2025-09-30 21:56:32 +02:00
parent 88fa68fa4c
commit 018d9a9022
5 changed files with 74 additions and 57 deletions

View File

@@ -15,7 +15,7 @@ server:
logging:
level: debug
level: trace
# Pretty print for human consumption otherwise json
pretty: true
# Log incoming requests
@@ -70,6 +70,12 @@ hosts:
domains:
- chat.localhost
- remotes:
- localhost
port: 8080
domains:
- gorilla.localhost
- remotes:
- www.google.com
port: 443

1
go.mod
View File

@@ -23,6 +23,7 @@ require (
require (
github.com/go-acme/lego v2.7.2+incompatible
github.com/go-acme/lego/v4 v4.24.0
github.com/google/uuid v1.6.0
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/urfave/negroni v1.0.0

2
go.sum
View File

@@ -10,6 +10,8 @@ github.com/go-acme/lego/v4 v4.24.0/go.mod h1:hkstZY6D0jylIrZbuNmEQrWQxTIfaJH7prw
github.com/go-jose/go-jose/v4 v4.0.5 h1:M6T8+mKZl/+fNNuFHvGIzDz7BTLQPIounk/b9dw3AaE=
github.com/go-jose/go-jose/v4 v4.0.5/go.mod h1:s3P1lRrkT8igV8D9OjyL4WRyHvjB6a4JSllnOrmmBOA=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=

View File

@@ -4,6 +4,7 @@ import (
"net/http"
"time"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
"github.com/urfave/negroni"
)
@@ -14,13 +15,28 @@ func RequestLogger(next http.Handler) http.Handler {
start := time.Now()
lrw := negroni.NewResponseWriter(w)
uuid := uuid.New().String()
log.Info().
Str("host", r.Host).
Str("uri", r.RequestURI).
Str("method", r.Method).
Str("uuid", uuid).
Msg("Received Request")
next.ServeHTTP(lrw, r)
duration := time.Since(start)
if duration.Milliseconds() > 500 {
log.Warn().Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Int("status", lrw.Status()).Int("size", lrw.Size()).Str("duration", duration.String()).Msg("Slow Request")
} else {
log.Info().Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Int("status", lrw.Status()).Int("size", lrw.Size()).Str("duration", duration.String()).Msg("Received Request")
log.Warn().
Str("host", r.Host).
Str("uri", r.RequestURI).
Str("method", r.Method).
Int("status", lrw.Status()).
Int("size", lrw.Size()).
Str("duration", duration.String()).
Str("uuid", uuid).
Msg("Slow Request")
}
})
}

View File

@@ -6,10 +6,12 @@ import (
"fmt"
"io"
"net/http"
"net/http/httputil"
"net/http/httptrace"
"net/textproto"
"net/url"
"slices"
"strings"
"sync"
"sync/atomic"
"github.com/rs/zerolog/log"
@@ -52,7 +54,7 @@ func (router *Router) roundRobin(host *Host) {
}
}
func rewriteRequestURL(r *http.Request, host *Host, remote string, ) error {
func rewriteRequestURL(r *http.Request, host *Host, remote string) error {
subUrlPath := r.URL.RequestURI()
var uri string
if host.Secure {
@@ -76,14 +78,11 @@ func rewriteRequestURL(r *http.Request, host *Host, remote string, ) error {
}
func (router *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// If trace enabled dump incoming request, could break request so exit early if that happens
if !dumpRequest(w, r) {
return
}
host, ok := router.domains.Get(r.Host)
transport := http.DefaultTransport
portLessHost, _ := strings.CutSuffix(r.Host, fmt.Sprintf(":%d", router.config.Server.Port))
host, ok := router.domains.Get(portLessHost)
if !ok {
log.Warn().Str("host", r.Host).Msg("Could not find Host")
log.Warn().Str("host", portLessHost).Msg("Could not find Host")
w.WriteHeader(http.StatusOK)
return
}
@@ -102,6 +101,7 @@ func (router *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if outreq.Body != nil {
defer outreq.Body.Close()
}
outreq.Close = false
reqUpType := upgradeType(outreq.Header)
if !isPrintableAscii(reqUpType) {
@@ -117,7 +117,7 @@ func (router *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if reqUpType != "" {
outreq.Header.Set("Connection", "Upgrade")
outreq.Header.Set("Upgrade", reqUpType)
log.Debug().Str("upgrade", reqUpType).Msg("Request upgrade")
log.Trace().Str("upgrade_type", reqUpType).Msg("Found upgrade Type")
}
stripClientProvidedXForwardHeaders(outreq.Header)
@@ -133,20 +133,34 @@ func (router *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
// Dump created request
if !dumpRequest(w, outreq) {
return
}
var (
roundTripMutex sync.Mutex
roundTripDone bool
)
res, err := router.client.Do(outreq)
trace := &httptrace.ClientTrace{
Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
roundTripMutex.Lock()
defer roundTripMutex.Unlock()
if roundTripDone {
return nil
}
h := w.Header()
copyHeader(h, http.Header(header))
w.WriteHeader(code)
clear(h)
return nil
},
}
outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))
res, err := transport.RoundTrip(outreq)
roundTripMutex.Lock()
roundTripDone = true
roundTripMutex.Unlock()
if err != nil {
log.Error().Err(err).Str("remote", remote).Int("port", host.Port).Msg("Could not complete request")
w.WriteHeader(http.StatusInternalServerError)
return
}
// If trace enabled dump response
if !dumpResponse(w, res) {
log.Error().Err(err).Any("out_request", outreq).Msg("Could not complete transport round trip")
return
}
@@ -266,7 +280,6 @@ func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
if _, err := io.Copy(c.user, c.backend); err != nil {
errc <- err
return
}
// backend conn has reached EOF so propogate close write to user conn
@@ -316,36 +329,15 @@ func isPrintableAscii(reqUpType string) bool {
}
func upgradeType(header http.Header) string {
for _, val := range header.Values("Connection") {
if strings.ToLower(val) == "upgrade" {
return header.Get("Upgrade")
// Iterate over Connection headers if those exist multiple times
for _, conVal := range header.Values("Connection") {
for _, headerVal := range strings.Split(conVal, ",") {
trimmed := strings.TrimSpace(headerVal)
if strings.EqualFold(trimmed, "upgrade") {
upType := header.Get("Upgrade")
return upType
}
}
}
return ""
}
func dumpRequest(w http.ResponseWriter, r *http.Request) bool {
if e := log.Trace(); e.Enabled() {
rDump, err := httputil.DumpRequest(r, true)
if err != nil {
log.Error().Err(err).Msg("Could not dump request")
w.WriteHeader(http.StatusInternalServerError)
return false
}
log.Trace().Str("dump", string(rDump)).Msg("Dumping Request")
}
return true
}
func dumpResponse(w http.ResponseWriter, r *http.Response) bool {
if e := log.Trace(); e.Enabled() {
dump, err := httputil.DumpResponse(r, true)
if err != nil {
log.Error().Err(err).Msg("Could not dump response")
w.WriteHeader(http.StatusInternalServerError)
return false
}
log.Trace().Str("dump", string(dump)).Msg("Dumping Response")
}
return true
}