Split route function into more readable smaller functions

This commit is contained in:
Pablu23
2024-11-07 00:38:19 +01:00
parent 3bf932363e
commit 403c89b068

142
router.go
View File

@@ -14,27 +14,21 @@ import (
type Router struct { type Router struct {
config *Config config *Config
domains *util.ImmutableMap[string, struct { domains *util.ImmutableMap[string, Host]
Port int client *http.Client
Remote string }
Secure bool
}] type Host struct {
client *http.Client Port int
Remote string
Secure bool
} }
func New(config *Config, client *http.Client) Router { func New(config *Config, client *http.Client) Router {
m := make(map[string]struct { m := make(map[string]Host)
Port int
Remote string
Secure bool
})
for _, host := range config.Hosts { for _, host := range config.Hosts {
for _, domain := range host.Domains { for _, domain := range host.Domains {
m[domain] = struct { m[domain] = Host{host.Port, host.Remote, host.Secure}
Port int
Remote string
Secure bool
}{host.Port, host.Remote, host.Secure}
} }
} }
@@ -95,18 +89,7 @@ func (router *Router) Healthz(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
} }
func (router *Router) Route(w http.ResponseWriter, r *http.Request) { func createRequest(r *http.Request, host *Host) (*http.Request, error) {
host, ok := router.domains.Get(r.Host)
if !ok {
log.Warn().Str("host", r.Host).Msg("Could not find Host")
w.WriteHeader(http.StatusOK)
return
}
if !dumpRequest(w, r) {
return
}
subUrlPath := r.URL.RequestURI() subUrlPath := r.URL.RequestURI()
var url string var url string
if host.Secure { if host.Secure {
@@ -117,65 +100,108 @@ func (router *Router) Route(w http.ResponseWriter, r *http.Request) {
req, err := http.NewRequest(r.Method, url, r.Body) req, err := http.NewRequest(r.Method, url, r.Body)
if err != nil { if err != nil {
log.Error().Err(err).Bool("secure", host.Secure).Str("remote", host.Remote).Str("path", subUrlPath).Int("port", host.Port).Msg("Could not create request") return nil, err
w.WriteHeader(http.StatusInternalServerError)
return
} }
for name, values := range r.Header { copyRequestHeader(r, req)
for _, value := range values {
req.Header.Set(name, value)
}
}
req.Header.Set("X-Forwarded-For", r.RemoteAddr) req.Header.Set("X-Forwarded-For", r.RemoteAddr)
for _, cookie := range r.Cookies() { for _, cookie := range r.Cookies() {
req.AddCookie(cookie) req.AddCookie(cookie)
} }
return req, nil
}
func copyRequestHeader(origin *http.Request, destination *http.Request) {
for name, values := range origin.Header {
for _, value := range values {
destination.Header.Set(name, value)
}
}
}
func applyResponseHeader(w http.ResponseWriter, res *http.Response) {
for name, values := range res.Header {
for _, value := range values {
w.Header().Set(name, value)
}
}
}
func applyCookies(w http.ResponseWriter, res *http.Response) {
cookies := res.Cookies()
for _, cookie := range cookies {
http.SetCookie(w, cookie)
}
}
func applyBody(w http.ResponseWriter, res *http.Response) error {
body, err := io.ReadAll(res.Body)
defer res.Body.Close()
if err != nil {
return err
}
_, err = w.Write(body)
if err != nil {
return err
}
w.WriteHeader(res.StatusCode)
return nil
}
func (router *Router) Route(w http.ResponseWriter, r *http.Request) {
// If trace enabled dump incoming request, could break request so exit early if that happens
if !dumpRequest(w, r) {
return
}
host, ok := router.domains.Get(r.Host)
if !ok {
log.Warn().Str("host", r.Host).Msg("Could not find Host")
w.WriteHeader(http.StatusOK)
return
}
req, err := createRequest(r, &host)
if err != nil {
log.Error().Err(err).Bool("secure", host.Secure).Str("remote", host.Remote).Int("port", host.Port).Msg("Could not create request")
w.WriteHeader(http.StatusInternalServerError)
return
}
// Dump created request
if !dumpRequest(w, req) { if !dumpRequest(w, req) {
return return
} }
res, err := router.client.Do(req) res, err := router.client.Do(req)
if err != nil { if err != nil {
log.Error().Err(err).Str("remote", host.Remote).Str("path", subUrlPath).Int("port", host.Port).Msg("Could not complete request") log.Error().Err(err).Str("remote", host.Remote).Int("port", host.Port).Msg("Could not complete request")
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
return return
} }
cookies := res.Cookies() // If trace enabled dump response
for _, cookie := range cookies {
http.SetCookie(w, cookie)
}
if !dumpResponse(w, res) { if !dumpResponse(w, res) {
return return
} }
applyCookies(w, res)
// Exit early because its a redirect // Exit early because its a redirect
// Maybe this should be before applying cookies or after applying headers
if !handleLocation(w, r, res) { if !handleLocation(w, r, res) {
return return
} }
for name, values := range res.Header { applyResponseHeader(w, res)
for _, value := range values {
w.Header().Set(name, value)
}
}
w.WriteHeader(res.StatusCode)
body, err := io.ReadAll(res.Body) err = applyBody(w, res)
defer res.Body.Close()
if err != nil { if err != nil {
log.Error().Err(err).Msg("Could not read body") log.Error().Err(err).Msg("Could not apply body")
w.WriteHeader(http.StatusInternalServerError)
return
}
_, err = w.Write(body)
if err != nil {
log.Error().Err(err).Msg("Could not write body")
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
return return
} }