Add more settings, cleanup code and add documentation

This commit is contained in:
Pablu23
2024-11-06 16:52:28 +01:00
parent a98b68177c
commit c0b711a992
4 changed files with 134 additions and 56 deletions

View File

@@ -57,10 +57,10 @@ func main() {
Handler: pipeline(mux), Handler: pipeline(mux),
} }
if config.Server.CertFile != "" && config.Server.KeyFile != "" { if config.Server.Ssl.Enabled {
server.TLSConfig = &tls.Config{ server.TLSConfig = &tls.Config{
GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := tls.LoadX509KeyPair(config.Server.CertFile, config.Server.KeyFile) cert, err := tls.LoadX509KeyPair(config.Server.Ssl.CertFile, config.Server.Ssl.KeyFile)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -68,9 +68,9 @@ func main() {
}, },
} }
log.Info().Int("port", config.Server.Port).Str("cert", config.Server.CertFile).Str("key", config.Server.KeyFile).Msg("Starting server") log.Info().Int("port", config.Server.Port).Str("cert", config.Server.Ssl.CertFile).Str("key", config.Server.Ssl.KeyFile).Msg("Starting server")
err := server.ListenAndServeTLS("", "") err := server.ListenAndServeTLS("", "")
log.Fatal().Err(err).Str("cert", config.Server.CertFile).Str("key", config.Server.KeyFile).Int("port", config.Server.Port).Msg("Could not start server") log.Fatal().Err(err).Str("cert", config.Server.Ssl.CertFile).Str("key", config.Server.Ssl.KeyFile).Int("port", config.Server.Port).Msg("Could not start server")
} else { } else {
log.Info().Int("port", config.Server.Port).Msg("Starting server") log.Info().Int("port", config.Server.Port).Msg("Starting server")
err := server.ListenAndServe() err := server.ListenAndServe()
@@ -111,19 +111,21 @@ func setupLogging(config *domainrouter.Config) {
} }
zerolog.SetGlobalLevel(logLevel) zerolog.SetGlobalLevel(logLevel)
log.Info().Str("level", config.Logging.Level).Msg("Set logging level")
if config.Logging.Pretty { if config.Logging.Pretty {
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
} }
if config.Logging.Path != "" { if config.Logging.File.Enabled {
var console io.Writer = os.Stderr var console io.Writer = os.Stderr
if config.Logging.Pretty { if config.Logging.Pretty {
console = zerolog.ConsoleWriter{Out: os.Stderr} console = zerolog.ConsoleWriter{Out: os.Stderr}
} }
log.Logger = log.Output(zerolog.MultiLevelWriter(console, &lumberjack.Logger{ log.Logger = log.Output(zerolog.MultiLevelWriter(console, &lumberjack.Logger{
Filename: config.Logging.Path, Filename: config.Logging.File.Path,
MaxAge: 14, MaxAge: config.Logging.File.MaxAge,
MaxBackups: 10, MaxBackups: config.Logging.File.MaxBackups,
})) }))
} }
} }

View File

@@ -6,14 +6,19 @@ type Config struct {
HealthEndpoint string `yaml:"healthz"` HealthEndpoint string `yaml:"healthz"`
} `yaml:"general"` } `yaml:"general"`
Server struct { Server struct {
Port int `yaml:"port"` Port int `yaml:"port"`
CertFile string `yaml:"certFile"` Ssl struct {
KeyFile string `yaml:"keyFile"` Enabled bool `yaml:"enabled"`
CertFile string `yaml:"certFile"`
KeyFile string `yaml:"keyFile"`
} `yaml:"ssl"`
} `yaml:"server"` } `yaml:"server"`
Hosts []struct { Hosts []struct {
Port int `yaml:"port"` Port int `yaml:"port"`
Remote string `yaml:"remote"`
Domains []string `yaml:"domains"` Domains []string `yaml:"domains"`
Public bool `yaml:"public"` Public bool `yaml:"public"`
Secure bool `yaml:"secure"`
} `yaml:"hosts"` } `yaml:"hosts"`
RateLimit struct { RateLimit struct {
Enabled bool `yaml:"enabled"` Enabled bool `yaml:"enabled"`
@@ -25,7 +30,12 @@ type Config struct {
Logging struct { Logging struct {
Level string `yaml:"level"` Level string `yaml:"level"`
Pretty bool `yaml:"pretty"` Pretty bool `yaml:"pretty"`
Path string `yaml:"path"`
Requests bool `yaml:"requests"` Requests bool `yaml:"requests"`
File struct {
Enabled bool `yaml:"enabled"`
Path string `yaml:"path"`
MaxAge int `yaml:"maxAge"`
MaxBackups int `yamls:"maxBackups"`
} `yaml:"file"`
} `yaml:"logging"` } `yaml:"logging"`
} }

View File

@@ -1,32 +1,65 @@
general:
announce: true
healthz: healthz
server: server:
port: 443 port: 443
certFile: server.crt ssl:
keyFile: server.key enabled: true
certFile: server.crt
keyFile: server.key
logging:
level: info
# Pretty print for human consumption otherwise json
pretty: true
# Log incoming requests
requests: true
# Log to file aswell as stderr
file:
enabled: false
maxAge: 14
maxBackups: 10
path: ~/logs/router
rateLimit: rateLimit:
enabled: true enabled: true
# How many requests per ip adress are allowed
bucketSize: 50 bucketSize: 50
# How many requests per ip address are refilled
refillSize: 10 refillSize: 10
# How often requests per ip address are refilled
refillTime: 1m refillTime: 1m
# How often Ip Addresses get cleaned up (only ip addresses with max allowed requests are cleaned up)
cleanupTime: 5m cleanupTime: 5m
hosts: hosts:
- port: 8181 # Remote address to request
- remote: localhost
# Port on which to request
port: 8181
# Health check if announce is true
public: true
# Domains which get redirected to host
domains: domains:
- localhost - localhost
- test.localhost - test.localhost
- test2.localhost
public: true - remote: localhost
- port: 8282 port: 8282
public: false
domains: domains:
- private.localhost - private.localhost
public: false
logging: - remote: www.google.com
level: debug port: 443
pretty: true public: false
requests: true # Uses https under the hood to communicate with the remote host
secure: true
domains:
- google.localhost
general:
# Expose health endpoint, that requests health endpoints from hosts which are public
announce: true
# Path to health endpoint on router, is allowed to conflict with hosts, but overwrites specific host endpoint
healthz: healthz

View File

@@ -14,15 +14,27 @@ import (
type Router struct { type Router struct {
config *Config config *Config
domains *util.ImmutableMap[string, int] domains *util.ImmutableMap[string, struct {
client *http.Client Port int
Remote string
Secure bool
}]
client *http.Client
} }
func New(config *Config, client *http.Client) Router { func New(config *Config, client *http.Client) Router {
m := make(map[string]int) m := make(map[string]struct {
Port int
Remote string
Secure bool
})
for _, host := range config.Hosts { for _, host := range config.Hosts {
for _, domain := range host.Domains { for _, domain := range host.Domains {
m[domain] = host.Port m[domain] = struct {
Port int
Remote string
Secure bool
}{host.Port, host.Remote, host.Secure}
} }
} }
@@ -79,6 +91,7 @@ func (router *Router) Healthz(w http.ResponseWriter, r *http.Request) {
func (router *Router) Route(w http.ResponseWriter, r *http.Request) { func (router *Router) Route(w http.ResponseWriter, r *http.Request) {
port, ok := router.domains.Get(r.Host) port, ok := router.domains.Get(r.Host)
if !ok { if !ok {
log.Warn().Str("host", r.Host).Msg("Could not find Host")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
return return
} }
@@ -88,9 +101,16 @@ func (router *Router) Route(w http.ResponseWriter, r *http.Request) {
} }
subUrlPath := r.URL.RequestURI() subUrlPath := r.URL.RequestURI()
req, err := http.NewRequest(r.Method, fmt.Sprintf("http://localhost:%d%s", port, subUrlPath), r.Body) var url string
if port.Secure {
url = fmt.Sprintf("https://%s:%d%s", port.Remote, port.Port, subUrlPath)
} else {
url = fmt.Sprintf("http://%s:%d%s", port.Remote, port.Port, subUrlPath)
}
req, err := http.NewRequest(r.Method, url, r.Body)
if err != nil { if err != nil {
log.Error().Err(err).Str("path", subUrlPath).Int("port", port).Msg("Could not create request") log.Error().Err(err).Str("remote", port.Remote).Str("path", subUrlPath).Int("port", port.Port).Msg("Could not create request")
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
return return
} }
@@ -112,7 +132,7 @@ func (router *Router) Route(w http.ResponseWriter, r *http.Request) {
res, err := router.client.Do(req) res, err := router.client.Do(req)
if err != nil { if err != nil {
log.Error().Err(err).Str("path", subUrlPath).Int("port", port).Msg("Could not complete request") log.Error().Err(err).Str("remote", port.Remote).Str("path", subUrlPath).Int("port", port.Port).Msg("Could not complete request")
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
return return
} }
@@ -126,31 +146,44 @@ func (router *Router) Route(w http.ResponseWriter, r *http.Request) {
return return
} }
if loc, err := res.Location(); !errors.Is(err, http.ErrNoLocation) { // Exit early because its a redirect
http.Redirect(w, r, loc.RequestURI(), http.StatusFound) if !handleLocation(w, r, res) {
} else { return
for name, values := range res.Header { }
for _, value := range values {
w.Header().Set(name, value)
}
}
w.WriteHeader(res.StatusCode)
body, err := io.ReadAll(res.Body) for name, values := range res.Header {
defer res.Body.Close() for _, value := range values {
if err != nil { w.Header().Set(name, value)
log.Error().Err(err).Msg("Could not read body")
w.WriteHeader(http.StatusInternalServerError)
return
}
_, err = w.Write(body)
if err != nil {
log.Error().Err(err).Msg("Could not write body")
w.WriteHeader(http.StatusInternalServerError)
return
} }
} }
w.WriteHeader(res.StatusCode)
body, err := io.ReadAll(res.Body)
defer res.Body.Close()
if err != nil {
log.Error().Err(err).Msg("Could not read body")
w.WriteHeader(http.StatusInternalServerError)
return
}
_, err = w.Write(body)
if err != nil {
log.Error().Err(err).Msg("Could not write body")
w.WriteHeader(http.StatusInternalServerError)
return
}
}
func handleLocation(w http.ResponseWriter, r *http.Request, res *http.Response) bool {
if loc, err := res.Location(); err == nil {
http.Redirect(w, r, loc.RequestURI(), http.StatusFound)
return false
} else if !errors.Is(err, http.ErrNoLocation) {
log.Error().Err(err).Msg("Could not extract location")
w.WriteHeader(http.StatusInternalServerError)
return false
}
return true
} }
func dumpRequest(w http.ResponseWriter, r *http.Request) bool { func dumpRequest(w http.ResponseWriter, r *http.Request) bool {