From a98b68177c4c1c72a38c59ee3ab3647fbb0b16d8 Mon Sep 17 00:00:00 2001 From: Pablu23 Date: Wed, 6 Nov 2024 10:21:30 +0100 Subject: [PATCH] Add better configuration, through yaml file --- Makefile | 2 +- cmd/domain-router/main.go | 116 +++++++++++++++++++++----------------- config.go | 31 ++++++++++ config.yaml | 32 +++++++++++ domains.conf | 3 - go.mod | 1 + go.sum | 3 + middleware/rate-limit.go | 2 +- router.go | 57 ++++++++++++++++++- 9 files changed, 188 insertions(+), 59 deletions(-) create mode 100644 config.go create mode 100644 config.yaml delete mode 100644 domains.conf diff --git a/Makefile b/Makefile index 015dda0..f8f8e21 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ run: build - sudo ./bin/domain-router --pretty --log-level debug + sudo ./bin/domain-router build: go build -o bin/domain-router cmd/domain-router/main.go diff --git a/cmd/domain-router/main.go b/cmd/domain-router/main.go index 1b4341b..4405a7a 100644 --- a/cmd/domain-router/main.go +++ b/cmd/domain-router/main.go @@ -1,16 +1,13 @@ package main import ( - "bufio" "crypto/tls" - "errors" "flag" "fmt" "io" "net/http" + "net/url" "os" - "strconv" - "strings" "time" domainrouter "github.com/pablu23/domain-router" @@ -18,55 +15,52 @@ import ( "github.com/rs/zerolog" "github.com/rs/zerolog/log" "gopkg.in/natefinch/lumberjack.v2" + "gopkg.in/yaml.v3" ) var ( - configFileFlag = flag.String("config", "domains.conf", "Path to Domain config file") - certFlag = flag.String("cert", "", "Path to cert file") - keyFlag = flag.String("key", "", "Path to key file") - portFlag = flag.Int("port", 80, "Port") - prettyLogsFlag = flag.Bool("pretty", false, "Pretty print? Default is json") - logPathFlag = flag.String("log", "", "Path to logfile, default is stderr") - logLevelFlag = flag.String("log-level", "info", "Log Level") + configFileFlag = flag.String("config", "config.yaml", "Path to config file") ) func main() { flag.Parse() - setupLogging() - - domains, err := loadConfig(*configFileFlag) + config, err := loadConfig(*configFileFlag) if err != nil { log.Fatal().Err(err).Str("path", *configFileFlag).Msg("Could not load Config") } + setupLogging(config) client := &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse }, } - router := domainrouter.New(domains, client) + router := domainrouter.New(config, client) mux := http.NewServeMux() mux.HandleFunc("/", router.Route) - limiter := middleware.NewLimiter(10, 250, 30*time.Second, 1*time.Minute) - limiter.Start() + if config.General.AnnouncePublic { + h, err := url.JoinPath("/", config.General.HealthEndpoint) + if err != nil { + log.Error().Err(err).Str("endpoint", config.General.HealthEndpoint).Msg("Could not create endpoint path") + h = "/healthz" + } + mux.HandleFunc(h, router.Healthz) + } - pipeline := middleware.Pipeline( - limiter.RateLimiter, - middleware.RequestLogger, - ) + pipeline := configureMiddleware(config) server := http.Server{ - Addr: fmt.Sprintf(":%d", *portFlag), + Addr: fmt.Sprintf(":%d", config.Server.Port), Handler: pipeline(mux), } - if *certFlag != "" && *keyFlag != "" { + if config.Server.CertFile != "" && config.Server.KeyFile != "" { server.TLSConfig = &tls.Config{ GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { - cert, err := tls.LoadX509KeyPair(*certFlag, *keyFlag) + cert, err := tls.LoadX509KeyPair(config.Server.CertFile, config.Server.KeyFile) if err != nil { return nil, err } @@ -74,61 +68,79 @@ func main() { }, } - log.Info().Int("port", *portFlag).Str("cert", *certFlag).Str("key", *keyFlag).Msg("Starting server") + log.Info().Int("port", config.Server.Port).Str("cert", config.Server.CertFile).Str("key", config.Server.KeyFile).Msg("Starting server") err := server.ListenAndServeTLS("", "") - log.Fatal().Err(err).Str("cert", *certFlag).Str("key", *keyFlag).Int("port", *portFlag).Msg("Could not start server") + log.Fatal().Err(err).Str("cert", config.Server.CertFile).Str("key", config.Server.KeyFile).Int("port", config.Server.Port).Msg("Could not start server") } else { - log.Info().Int("port", *portFlag).Msg("Starting server") + log.Info().Int("port", config.Server.Port).Msg("Starting server") err := server.ListenAndServe() - log.Fatal().Err(err).Int("port", *portFlag).Msg("Could not start server") + log.Fatal().Err(err).Int("port", config.Server.Port).Msg("Could not start server") } } -func setupLogging() { - logLevel, err := zerolog.ParseLevel(*logLevelFlag) +func configureMiddleware(config *domainrouter.Config) middleware.Middleware { + middlewares := make([]middleware.Middleware, 0) + + if config.RateLimit.Enabled { + refillTicker, err := time.ParseDuration(config.RateLimit.RefillTicker) + if err != nil { + log.Fatal().Err(err).Str("refill", config.RateLimit.RefillTicker).Msg("Could not parse refill Ticker") + } + + cleanupTicker, err := time.ParseDuration(config.RateLimit.CleanupTicker) + if err != nil { + log.Fatal().Err(err).Str("cleanup", config.RateLimit.CleanupTicker).Msg("Could not parse cleanup Ticker") + } + limiter := middleware.NewLimiter(config.RateLimit.BucketSize, config.RateLimit.BucketRefill, refillTicker, cleanupTicker) + limiter.Start() + middlewares = append(middlewares, limiter.RateLimiter) + } + + if config.Logging.Requests { + middlewares = append(middlewares, middleware.RequestLogger) + } + + pipeline := middleware.Pipeline(middlewares...) + return pipeline +} + +func setupLogging(config *domainrouter.Config) { + logLevel, err := zerolog.ParseLevel(config.Logging.Level) if err != nil { - log.Fatal().Err(err).Str("level", *logLevelFlag).Msg("Could not parse string to level") + log.Fatal().Err(err).Str("level", config.Logging.Level).Msg("Could not parse string to level") } zerolog.SetGlobalLevel(logLevel) - if *prettyLogsFlag { + if config.Logging.Pretty { log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) } - if *logPathFlag != "" { + if config.Logging.Path != "" { var console io.Writer = os.Stderr - if *prettyLogsFlag { + if config.Logging.Pretty { console = zerolog.ConsoleWriter{Out: os.Stderr} } log.Logger = log.Output(zerolog.MultiLevelWriter(console, &lumberjack.Logger{ - Filename: *logPathFlag, + Filename: config.Logging.Path, MaxAge: 14, MaxBackups: 10, })) } } -func loadConfig(path string) (map[string]int, error) { - file, err := os.Open(path) +func loadConfig(path string) (*domainrouter.Config, error) { + f, err := os.Open(path) if err != nil { return nil, err } - scanner := bufio.NewScanner(file) - scanner.Split(bufio.ScanLines) + defer f.Close() - m := make(map[string]int) - for scanner.Scan() { - line := scanner.Text() - params := strings.Split(line, ";") - if len(params) <= 1 { - return nil, errors.New("Line does not contain enough Parameters") - } - port, err := strconv.Atoi(params[1]) - if err != nil { - return nil, err - } - m[params[0]] = port + var cfg domainrouter.Config + decoder := yaml.NewDecoder(f) + err = decoder.Decode(&cfg) + if err != nil { + return nil, err } - return m, nil + return &cfg, err } diff --git a/config.go b/config.go new file mode 100644 index 0000000..4a85a9c --- /dev/null +++ b/config.go @@ -0,0 +1,31 @@ +package domainrouter + +type Config struct { + General struct { + AnnouncePublic bool `yaml:"announce"` + HealthEndpoint string `yaml:"healthz"` + } `yaml:"general"` + Server struct { + Port int `yaml:"port"` + CertFile string `yaml:"certFile"` + KeyFile string `yaml:"keyFile"` + } `yaml:"server"` + Hosts []struct { + Port int `yaml:"port"` + Domains []string `yaml:"domains"` + Public bool `yaml:"public"` + } `yaml:"hosts"` + RateLimit struct { + Enabled bool `yaml:"enabled"` + BucketSize int `yaml:"bucketSize"` + RefillTicker string `yaml:"refillTime"` + CleanupTicker string `yaml:"cleanupTime"` + BucketRefill int `yaml:"refillSize"` + } `yaml:"rateLimit"` + Logging struct { + Level string `yaml:"level"` + Pretty bool `yaml:"pretty"` + Path string `yaml:"path"` + Requests bool `yaml:"requests"` + } `yaml:"logging"` +} diff --git a/config.yaml b/config.yaml new file mode 100644 index 0000000..15a0fd0 --- /dev/null +++ b/config.yaml @@ -0,0 +1,32 @@ +general: + announce: true + healthz: healthz + +server: + port: 443 + certFile: server.crt + keyFile: server.key + +rateLimit: + enabled: true + bucketSize: 50 + refillSize: 10 + refillTime: 1m + cleanupTime: 5m + +hosts: + - port: 8181 + domains: + - localhost + - test.localhost + - test2.localhost + public: true + - port: 8282 + domains: + - private.localhost + public: false + +logging: + level: debug + pretty: true + requests: true diff --git a/domains.conf b/domains.conf deleted file mode 100644 index 85c105f..0000000 --- a/domains.conf +++ /dev/null @@ -1,3 +0,0 @@ -test.localhost;8181 -test2.localhost;8282 -localhost;8080 diff --git a/go.mod b/go.mod index d351408..51a8f93 100644 --- a/go.mod +++ b/go.mod @@ -10,4 +10,5 @@ require ( github.com/urfave/negroni v1.0.0 golang.org/x/sys v0.12.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 + gopkg.in/yaml.v3 v3.0.1 ) diff --git a/go.sum b/go.sum index f0da108..1d68d5e 100644 --- a/go.sum +++ b/go.sum @@ -15,5 +15,8 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/middleware/rate-limit.go b/middleware/rate-limit.go index b4a197c..18f533c 100644 --- a/middleware/rate-limit.go +++ b/middleware/rate-limit.go @@ -82,7 +82,7 @@ func (l *Limiter) Manage() { } func (l *Limiter) RateLimiter(next http.Handler) http.Handler { - log.Info().Msg("Enabling Ratelimits") + log.Info().Int("bucket_size", l.bucketSize).Int("bucket_refill", l.bucketRefill).Msg("Enabling Ratelimits") return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { addr := strings.Split(r.RemoteAddr, ":")[0] l.rwLock.RLock() diff --git a/router.go b/router.go index e62b854..419dec0 100644 --- a/router.go +++ b/router.go @@ -1,6 +1,7 @@ package domainrouter import ( + "encoding/json" "errors" "fmt" "io" @@ -12,17 +13,69 @@ import ( ) type Router struct { + config *Config domains *util.ImmutableMap[string, int] client *http.Client } -func New(domains map[string]int, client *http.Client) Router { +func New(config *Config, client *http.Client) Router { + m := make(map[string]int) + for _, host := range config.Hosts { + for _, domain := range host.Domains { + m[domain] = host.Port + } + } + return Router{ - domains: util.NewImmutableMap(domains), + config: config, + domains: util.NewImmutableMap(m), client: client, } } +func (router *Router) Healthz(w http.ResponseWriter, r *http.Request) { + if !router.config.General.AnnouncePublic { + http.NotFound(w, r) + return + } + + result := make([]struct { + Domain string + Healthy bool + }, 0) + + for _, host := range router.config.Hosts { + if !host.Public { + continue + } + + healthy := true + res, err := router.client.Get(fmt.Sprintf("http://localhost:%d/healthz", host.Port)) + if err != nil { + log.Warn().Err(err).Int("port", host.Port).Msg("Unhealthy") + healthy = false + } + + for _, domain := range host.Domains { + result = append(result, struct { + Domain string + Healthy bool + }{domain, healthy && res.StatusCode == 200}) + } + } + + data, err := json.Marshal(&result) + if err != nil { + log.Error().Err(err).Msg("Could not json encode Healthz") + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.Header().Add("Content-Type", "application/json") + w.Write(data) + w.WriteHeader(http.StatusOK) +} + func (router *Router) Route(w http.ResponseWriter, r *http.Request) { port, ok := router.domains.Get(r.Host) if !ok {