From 572e1177efef758cebd0e2407d72be738b62933b Mon Sep 17 00:00:00 2001 From: Pablu23 Date: Wed, 1 Oct 2025 14:36:30 +0200 Subject: [PATCH] Add graceful stopping to server, and extend middleware and pipeline logic --- cmd/domain-router/main.go | 48 +++++++++++++++++++++++++++++---------- middleware/logging.go | 11 ++++++++- middleware/metrics.go | 15 +++++++++++- middleware/pipeline.go | 46 +++++++++++++++++++++++++++++++++---- middleware/rate-limit.go | 18 +++++++++------ 5 files changed, 113 insertions(+), 25 deletions(-) diff --git a/cmd/domain-router/main.go b/cmd/domain-router/main.go index 8cdc6b1..5fbd1f8 100644 --- a/cmd/domain-router/main.go +++ b/cmd/domain-router/main.go @@ -1,12 +1,17 @@ package main import ( + "context" "crypto/tls" + "errors" "flag" "fmt" "io" "net/http" "os" + "os/signal" + "sync" + "syscall" "time" domainrouter "github.com/pablu23/domain-router" @@ -37,17 +42,32 @@ func main() { }, } + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) + router := domainrouter.New(config, client) mux := http.NewServeMux() mux.HandleFunc("/", router.ServeHTTP) pipeline := configureMiddleware(config) + pipeline.Manage() server := http.Server{ - Addr: fmt.Sprintf(":%d", config.Server.Port), - Handler: pipeline(mux), + Addr: fmt.Sprintf(":%d", config.Server.Port), + // this is rather bad looking + Handler: pipeline.Use()(mux), } + var wg sync.WaitGroup + wg.Add(1) + go func() { + <-sigs + log.Info().Msg("Stopping server") + server.Shutdown(context.Background()) + pipeline.Stop() + wg.Done() + }() + if config.Server.Ssl.Enabled { server.TLSConfig = &tls.Config{ GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { @@ -71,16 +91,23 @@ func main() { } log.Info().Int("port", config.Server.Port).Str("cert", config.Server.Ssl.CertFile).Str("key", config.Server.Ssl.KeyFile).Msg("Starting server") err := server.ListenAndServeTLS("", "") - log.Fatal().Err(err).Str("cert", config.Server.Ssl.CertFile).Str("key", config.Server.Ssl.KeyFile).Int("port", config.Server.Port).Msg("Could not start server") + if err != nil && !errors.Is(err, http.ErrServerClosed) { + log.Fatal().Err(err).Str("cert", config.Server.Ssl.CertFile).Str("key", config.Server.Ssl.KeyFile).Int("port", config.Server.Port).Msg("Could not start server") + } } else { log.Info().Int("port", config.Server.Port).Msg("Starting server") err := server.ListenAndServe() - log.Fatal().Err(err).Int("port", config.Server.Port).Msg("Could not start server") + if err != nil && !errors.Is(err, http.ErrServerClosed) { + log.Fatal().Err(err).Int("port", config.Server.Port).Msg("Could not start server") + } } + + wg.Wait() + log.Info().Msg("Server shutdown completly, have a nice day") } -func configureMiddleware(config *domainrouter.Config) middleware.Middleware { - middlewares := make([]middleware.Middleware, 0) +func configureMiddleware(config *domainrouter.Config) *middleware.Pipeline { + pipeline := middleware.NewPipeline() if config.RateLimit.Enabled { refillTicker, err := time.ParseDuration(config.RateLimit.RefillTicker) @@ -93,19 +120,16 @@ func configureMiddleware(config *domainrouter.Config) middleware.Middleware { log.Fatal().Err(err).Str("cleanup", config.RateLimit.CleanupTicker).Msg("Could not parse cleanup Ticker") } limiter := middleware.NewLimiter(config.RateLimit.BucketSize, config.RateLimit.BucketRefill, refillTicker, cleanupTicker) - limiter.Start() - middlewares = append(middlewares, limiter.RateLimiter) + pipeline.AddMiddleware(limiter) } if config.Logging.Requests { - middlewares = append(middlewares, middleware.RequestLogger) + pipeline.AddMiddleware(&middleware.RequestLogger{}) } metrics := middleware.NewMetrics(512, 1*time.Minute, "tmp_metrics.json") - go metrics.Manage() - middlewares = append(middlewares, metrics.RequestMetrics) + pipeline.AddMiddleware(metrics) - pipeline := middleware.Pipeline(middlewares...) return pipeline } diff --git a/middleware/logging.go b/middleware/logging.go index b580866..8b20be6 100644 --- a/middleware/logging.go +++ b/middleware/logging.go @@ -9,7 +9,16 @@ import ( "github.com/urfave/negroni" ) -func RequestLogger(next http.Handler) http.Handler { +type RequestLogger struct{} + +func (_ *RequestLogger) Stop() { + log.Info().Msg("Stopped Logging") +} + +func (_ *RequestLogger) Manage() { +} + +func (_ *RequestLogger) Use(next http.Handler) http.Handler { log.Info().Msg("Enabling Logging") return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() diff --git a/middleware/metrics.go b/middleware/metrics.go index cf13a08..98f7791 100644 --- a/middleware/metrics.go +++ b/middleware/metrics.go @@ -17,6 +17,7 @@ type Metrics struct { endpointMetrics []EndpointMetrics ticker *time.Ticker file string + stop chan bool } type EndpointMetrics struct { @@ -44,7 +45,7 @@ func NewMetrics(bufferSize int, flushTimeout time.Duration, file string) *Metric } } -func (m *Metrics) RequestMetrics(next http.Handler) http.Handler { +func (m *Metrics) Use(next http.Handler) http.Handler { log.Info().Msg("Enabling Request Metrics") return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() @@ -71,6 +72,8 @@ func (m *Metrics) Manage() { m.calculateDuration(rm) case <-m.ticker.C: m.Flush() + case <-m.stop: + return } } } @@ -129,3 +132,13 @@ func (m *Metrics) Flush() { log.Info().Str("file", m.file).Int("count", len(a)).Msg("Completed Metrics flush") } + +func (m *Metrics) Stop() { + log.Info().Msg("Stopping Request Metrics") + for len(m.c) > 0 { + rm := <- m.c + m.calculateDuration(rm) + } + m.Flush() + log.Info().Msg("Stopped Request Metrics") +} diff --git a/middleware/pipeline.go b/middleware/pipeline.go index f95a2d8..2513114 100644 --- a/middleware/pipeline.go +++ b/middleware/pipeline.go @@ -5,14 +5,52 @@ import ( "slices" ) -type Middleware func(http.Handler) http.Handler +type Middleware interface { + Use(http.Handler) http.Handler + Manage() + Stop() +} -func Pipeline(funcs ...Middleware) Middleware { +type Pipeline struct { + middleware []Middleware +} + +func NewPipeline() *Pipeline { + return &Pipeline{} +} + +func (p *Pipeline) AddMiddleware(m Middleware) { + p.middleware = append(p.middleware, m) +} + +func (p *Pipeline) Use() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { - for _, m := range slices.Backward(funcs) { - next = m(next) + for _, m := range slices.Backward(p.middleware) { + next = m.Use(next) } return next } } + +func (p *Pipeline) Stop() { + for _, m := range p.middleware { + m.Stop() + } +} + +func (p *Pipeline) Manage() { + for _, m := range p.middleware { + go m.Manage() + } +} + +// func Pipeline(funcs ...Middleware) func(http.Handler) http.Handler { +// return func(next http.Handler) http.Handler { +// for _, m := range slices.Backward(funcs) { +// next = m.Use(next) +// } +// +// return next +// } +// } diff --git a/middleware/rate-limit.go b/middleware/rate-limit.go index ac0e471..c5dfcfa 100644 --- a/middleware/rate-limit.go +++ b/middleware/rate-limit.go @@ -18,10 +18,11 @@ type Limiter struct { bucketRefill int rwLock *sync.RWMutex rateChannel chan string + stop chan struct{} } -func NewLimiter(maxRequests int, refills int, refillInterval time.Duration, cleanupInterval time.Duration) Limiter { - return Limiter{ +func NewLimiter(maxRequests int, refills int, refillInterval time.Duration, cleanupInterval time.Duration) *Limiter { + return &Limiter{ currentBuckets: make(map[string]*atomic.Int64), bucketSize: maxRequests, refillTicker: time.NewTicker(refillInterval), @@ -32,14 +33,15 @@ func NewLimiter(maxRequests int, refills int, refillInterval time.Duration, clea } } -func (l *Limiter) Start() { - go l.Manage() -} - func (l *Limiter) UpdateCleanupTime(new time.Duration) { l.cleanupTicker.Reset(new) } +func (l *Limiter) Stop() { + l.stop <- struct{}{} + log.Info().Msg("Stopped Ratelimits") +} + func (l *Limiter) Manage() { for { select { @@ -77,6 +79,8 @@ func (l *Limiter) Manage() { l.rwLock.Unlock() duration := time.Since(start) log.Debug().Str("duration", duration.String()).Int("deleted_buckets", deletedBuckets).Msg("Cleaned up Buckets") + case <- l.stop: + return } } } @@ -93,7 +97,7 @@ func (l *Limiter) AddIfExists(ip string) bool { return false } -func (l *Limiter) RateLimiter(next http.Handler) http.Handler { +func (l *Limiter) Use(next http.Handler) http.Handler { log.Info().Int("bucket_size", l.bucketSize).Int("bucket_refill", l.bucketRefill).Msg("Enabling Ratelimits") return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { addr := strings.Split(r.RemoteAddr, ":")[0]