From c0b711a9924b50a96219c869b52330ea6dc1cfcc Mon Sep 17 00:00:00 2001 From: Pablu23 Date: Wed, 6 Nov 2024 16:52:28 +0100 Subject: [PATCH] Add more settings, cleanup code and add documentation --- cmd/domain-router/main.go | 18 ++++---- config.go | 18 ++++++-- config.yaml | 63 ++++++++++++++++++++------- router.go | 91 ++++++++++++++++++++++++++------------- 4 files changed, 134 insertions(+), 56 deletions(-) diff --git a/cmd/domain-router/main.go b/cmd/domain-router/main.go index 4405a7a..540cc14 100644 --- a/cmd/domain-router/main.go +++ b/cmd/domain-router/main.go @@ -57,10 +57,10 @@ func main() { Handler: pipeline(mux), } - if config.Server.CertFile != "" && config.Server.KeyFile != "" { + if config.Server.Ssl.Enabled { server.TLSConfig = &tls.Config{ GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { - cert, err := tls.LoadX509KeyPair(config.Server.CertFile, config.Server.KeyFile) + cert, err := tls.LoadX509KeyPair(config.Server.Ssl.CertFile, config.Server.Ssl.KeyFile) if err != nil { return nil, err } @@ -68,9 +68,9 @@ func main() { }, } - log.Info().Int("port", config.Server.Port).Str("cert", config.Server.CertFile).Str("key", config.Server.KeyFile).Msg("Starting server") + log.Info().Int("port", config.Server.Port).Str("cert", config.Server.Ssl.CertFile).Str("key", config.Server.Ssl.KeyFile).Msg("Starting server") err := server.ListenAndServeTLS("", "") - log.Fatal().Err(err).Str("cert", config.Server.CertFile).Str("key", config.Server.KeyFile).Int("port", config.Server.Port).Msg("Could not start server") + log.Fatal().Err(err).Str("cert", config.Server.Ssl.CertFile).Str("key", config.Server.Ssl.KeyFile).Int("port", config.Server.Port).Msg("Could not start server") } else { log.Info().Int("port", config.Server.Port).Msg("Starting server") err := server.ListenAndServe() @@ -111,19 +111,21 @@ func setupLogging(config *domainrouter.Config) { } zerolog.SetGlobalLevel(logLevel) + log.Info().Str("level", config.Logging.Level).Msg("Set logging level") if config.Logging.Pretty { log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) } - if config.Logging.Path != "" { + if config.Logging.File.Enabled { var console io.Writer = os.Stderr if config.Logging.Pretty { console = zerolog.ConsoleWriter{Out: os.Stderr} } + log.Logger = log.Output(zerolog.MultiLevelWriter(console, &lumberjack.Logger{ - Filename: config.Logging.Path, - MaxAge: 14, - MaxBackups: 10, + Filename: config.Logging.File.Path, + MaxAge: config.Logging.File.MaxAge, + MaxBackups: config.Logging.File.MaxBackups, })) } } diff --git a/config.go b/config.go index 4a85a9c..14374a5 100644 --- a/config.go +++ b/config.go @@ -6,14 +6,19 @@ type Config struct { HealthEndpoint string `yaml:"healthz"` } `yaml:"general"` Server struct { - Port int `yaml:"port"` - CertFile string `yaml:"certFile"` - KeyFile string `yaml:"keyFile"` + Port int `yaml:"port"` + Ssl struct { + Enabled bool `yaml:"enabled"` + CertFile string `yaml:"certFile"` + KeyFile string `yaml:"keyFile"` + } `yaml:"ssl"` } `yaml:"server"` Hosts []struct { Port int `yaml:"port"` + Remote string `yaml:"remote"` Domains []string `yaml:"domains"` Public bool `yaml:"public"` + Secure bool `yaml:"secure"` } `yaml:"hosts"` RateLimit struct { Enabled bool `yaml:"enabled"` @@ -25,7 +30,12 @@ type Config struct { Logging struct { Level string `yaml:"level"` Pretty bool `yaml:"pretty"` - Path string `yaml:"path"` Requests bool `yaml:"requests"` + File struct { + Enabled bool `yaml:"enabled"` + Path string `yaml:"path"` + MaxAge int `yaml:"maxAge"` + MaxBackups int `yamls:"maxBackups"` + } `yaml:"file"` } `yaml:"logging"` } diff --git a/config.yaml b/config.yaml index 15a0fd0..0dca340 100644 --- a/config.yaml +++ b/config.yaml @@ -1,32 +1,65 @@ -general: - announce: true - healthz: healthz - server: port: 443 - certFile: server.crt - keyFile: server.key + ssl: + enabled: true + certFile: server.crt + keyFile: server.key + + +logging: + level: info + # Pretty print for human consumption otherwise json + pretty: true + # Log incoming requests + requests: true + # Log to file aswell as stderr + file: + enabled: false + maxAge: 14 + maxBackups: 10 + path: ~/logs/router + rateLimit: enabled: true + # How many requests per ip adress are allowed bucketSize: 50 + # How many requests per ip address are refilled refillSize: 10 + # How often requests per ip address are refilled refillTime: 1m + # How often Ip Addresses get cleaned up (only ip addresses with max allowed requests are cleaned up) cleanupTime: 5m + hosts: - - port: 8181 + # Remote address to request + - remote: localhost + # Port on which to request + port: 8181 + # Health check if announce is true + public: true + # Domains which get redirected to host domains: - localhost - test.localhost - - test2.localhost - public: true - - port: 8282 + + - remote: localhost + port: 8282 + public: false domains: - private.localhost - public: false -logging: - level: debug - pretty: true - requests: true + - remote: www.google.com + port: 443 + public: false + # Uses https under the hood to communicate with the remote host + secure: true + domains: + - google.localhost + +general: + # Expose health endpoint, that requests health endpoints from hosts which are public + announce: true + # Path to health endpoint on router, is allowed to conflict with hosts, but overwrites specific host endpoint + healthz: healthz diff --git a/router.go b/router.go index 419dec0..84ceaf8 100644 --- a/router.go +++ b/router.go @@ -14,15 +14,27 @@ import ( type Router struct { config *Config - domains *util.ImmutableMap[string, int] - client *http.Client + domains *util.ImmutableMap[string, struct { + Port int + Remote string + Secure bool + }] + client *http.Client } func New(config *Config, client *http.Client) Router { - m := make(map[string]int) + m := make(map[string]struct { + Port int + Remote string + Secure bool + }) for _, host := range config.Hosts { for _, domain := range host.Domains { - m[domain] = host.Port + m[domain] = struct { + Port int + Remote string + Secure bool + }{host.Port, host.Remote, host.Secure} } } @@ -79,6 +91,7 @@ func (router *Router) Healthz(w http.ResponseWriter, r *http.Request) { func (router *Router) Route(w http.ResponseWriter, r *http.Request) { port, ok := router.domains.Get(r.Host) if !ok { + log.Warn().Str("host", r.Host).Msg("Could not find Host") w.WriteHeader(http.StatusOK) return } @@ -88,9 +101,16 @@ func (router *Router) Route(w http.ResponseWriter, r *http.Request) { } subUrlPath := r.URL.RequestURI() - req, err := http.NewRequest(r.Method, fmt.Sprintf("http://localhost:%d%s", port, subUrlPath), r.Body) + var url string + if port.Secure { + url = fmt.Sprintf("https://%s:%d%s", port.Remote, port.Port, subUrlPath) + } else { + url = fmt.Sprintf("http://%s:%d%s", port.Remote, port.Port, subUrlPath) + } + + req, err := http.NewRequest(r.Method, url, r.Body) if err != nil { - log.Error().Err(err).Str("path", subUrlPath).Int("port", port).Msg("Could not create request") + log.Error().Err(err).Str("remote", port.Remote).Str("path", subUrlPath).Int("port", port.Port).Msg("Could not create request") w.WriteHeader(http.StatusInternalServerError) return } @@ -112,7 +132,7 @@ func (router *Router) Route(w http.ResponseWriter, r *http.Request) { res, err := router.client.Do(req) if err != nil { - log.Error().Err(err).Str("path", subUrlPath).Int("port", port).Msg("Could not complete request") + log.Error().Err(err).Str("remote", port.Remote).Str("path", subUrlPath).Int("port", port.Port).Msg("Could not complete request") w.WriteHeader(http.StatusInternalServerError) return } @@ -126,31 +146,44 @@ func (router *Router) Route(w http.ResponseWriter, r *http.Request) { return } - if loc, err := res.Location(); !errors.Is(err, http.ErrNoLocation) { - http.Redirect(w, r, loc.RequestURI(), http.StatusFound) - } else { - for name, values := range res.Header { - for _, value := range values { - w.Header().Set(name, value) - } - } - w.WriteHeader(res.StatusCode) + // Exit early because its a redirect + if !handleLocation(w, r, res) { + return + } - body, err := io.ReadAll(res.Body) - defer res.Body.Close() - if err != nil { - log.Error().Err(err).Msg("Could not read body") - w.WriteHeader(http.StatusInternalServerError) - return - } - - _, err = w.Write(body) - if err != nil { - log.Error().Err(err).Msg("Could not write body") - w.WriteHeader(http.StatusInternalServerError) - return + for name, values := range res.Header { + for _, value := range values { + w.Header().Set(name, value) } } + w.WriteHeader(res.StatusCode) + + body, err := io.ReadAll(res.Body) + defer res.Body.Close() + if err != nil { + log.Error().Err(err).Msg("Could not read body") + w.WriteHeader(http.StatusInternalServerError) + return + } + + _, err = w.Write(body) + if err != nil { + log.Error().Err(err).Msg("Could not write body") + w.WriteHeader(http.StatusInternalServerError) + return + } +} + +func handleLocation(w http.ResponseWriter, r *http.Request, res *http.Response) bool { + if loc, err := res.Location(); err == nil { + http.Redirect(w, r, loc.RequestURI(), http.StatusFound) + return false + } else if !errors.Is(err, http.ErrNoLocation) { + log.Error().Err(err).Msg("Could not extract location") + w.WriteHeader(http.StatusInternalServerError) + return false + } + return true } func dumpRequest(w http.ResponseWriter, r *http.Request) bool {