Add graceful stopping to server, and extend middleware and pipeline logic

This commit is contained in:
Pablu23
2025-10-01 14:36:30 +02:00
parent 66f2811fff
commit 572e1177ef
5 changed files with 113 additions and 25 deletions

View File

@@ -1,12 +1,17 @@
package main package main
import ( import (
"context"
"crypto/tls" "crypto/tls"
"errors"
"flag" "flag"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"os" "os"
"os/signal"
"sync"
"syscall"
"time" "time"
domainrouter "github.com/pablu23/domain-router" domainrouter "github.com/pablu23/domain-router"
@@ -37,17 +42,32 @@ func main() {
}, },
} }
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
router := domainrouter.New(config, client) router := domainrouter.New(config, client)
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("/", router.ServeHTTP) mux.HandleFunc("/", router.ServeHTTP)
pipeline := configureMiddleware(config) pipeline := configureMiddleware(config)
pipeline.Manage()
server := http.Server{ server := http.Server{
Addr: fmt.Sprintf(":%d", config.Server.Port), Addr: fmt.Sprintf(":%d", config.Server.Port),
Handler: pipeline(mux), // this is rather bad looking
Handler: pipeline.Use()(mux),
} }
var wg sync.WaitGroup
wg.Add(1)
go func() {
<-sigs
log.Info().Msg("Stopping server")
server.Shutdown(context.Background())
pipeline.Stop()
wg.Done()
}()
if config.Server.Ssl.Enabled { if config.Server.Ssl.Enabled {
server.TLSConfig = &tls.Config{ server.TLSConfig = &tls.Config{
GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
@@ -71,16 +91,23 @@ func main() {
} }
log.Info().Int("port", config.Server.Port).Str("cert", config.Server.Ssl.CertFile).Str("key", config.Server.Ssl.KeyFile).Msg("Starting server") log.Info().Int("port", config.Server.Port).Str("cert", config.Server.Ssl.CertFile).Str("key", config.Server.Ssl.KeyFile).Msg("Starting server")
err := server.ListenAndServeTLS("", "") err := server.ListenAndServeTLS("", "")
log.Fatal().Err(err).Str("cert", config.Server.Ssl.CertFile).Str("key", config.Server.Ssl.KeyFile).Int("port", config.Server.Port).Msg("Could not start server") if err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatal().Err(err).Str("cert", config.Server.Ssl.CertFile).Str("key", config.Server.Ssl.KeyFile).Int("port", config.Server.Port).Msg("Could not start server")
}
} else { } else {
log.Info().Int("port", config.Server.Port).Msg("Starting server") log.Info().Int("port", config.Server.Port).Msg("Starting server")
err := server.ListenAndServe() err := server.ListenAndServe()
log.Fatal().Err(err).Int("port", config.Server.Port).Msg("Could not start server") if err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatal().Err(err).Int("port", config.Server.Port).Msg("Could not start server")
}
} }
wg.Wait()
log.Info().Msg("Server shutdown completly, have a nice day")
} }
func configureMiddleware(config *domainrouter.Config) middleware.Middleware { func configureMiddleware(config *domainrouter.Config) *middleware.Pipeline {
middlewares := make([]middleware.Middleware, 0) pipeline := middleware.NewPipeline()
if config.RateLimit.Enabled { if config.RateLimit.Enabled {
refillTicker, err := time.ParseDuration(config.RateLimit.RefillTicker) refillTicker, err := time.ParseDuration(config.RateLimit.RefillTicker)
@@ -93,19 +120,16 @@ func configureMiddleware(config *domainrouter.Config) middleware.Middleware {
log.Fatal().Err(err).Str("cleanup", config.RateLimit.CleanupTicker).Msg("Could not parse cleanup Ticker") log.Fatal().Err(err).Str("cleanup", config.RateLimit.CleanupTicker).Msg("Could not parse cleanup Ticker")
} }
limiter := middleware.NewLimiter(config.RateLimit.BucketSize, config.RateLimit.BucketRefill, refillTicker, cleanupTicker) limiter := middleware.NewLimiter(config.RateLimit.BucketSize, config.RateLimit.BucketRefill, refillTicker, cleanupTicker)
limiter.Start() pipeline.AddMiddleware(limiter)
middlewares = append(middlewares, limiter.RateLimiter)
} }
if config.Logging.Requests { if config.Logging.Requests {
middlewares = append(middlewares, middleware.RequestLogger) pipeline.AddMiddleware(&middleware.RequestLogger{})
} }
metrics := middleware.NewMetrics(512, 1*time.Minute, "tmp_metrics.json") metrics := middleware.NewMetrics(512, 1*time.Minute, "tmp_metrics.json")
go metrics.Manage() pipeline.AddMiddleware(metrics)
middlewares = append(middlewares, metrics.RequestMetrics)
pipeline := middleware.Pipeline(middlewares...)
return pipeline return pipeline
} }

View File

@@ -9,7 +9,16 @@ import (
"github.com/urfave/negroni" "github.com/urfave/negroni"
) )
func RequestLogger(next http.Handler) http.Handler { type RequestLogger struct{}
func (_ *RequestLogger) Stop() {
log.Info().Msg("Stopped Logging")
}
func (_ *RequestLogger) Manage() {
}
func (_ *RequestLogger) Use(next http.Handler) http.Handler {
log.Info().Msg("Enabling Logging") log.Info().Msg("Enabling Logging")
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now() start := time.Now()

View File

@@ -17,6 +17,7 @@ type Metrics struct {
endpointMetrics []EndpointMetrics endpointMetrics []EndpointMetrics
ticker *time.Ticker ticker *time.Ticker
file string file string
stop chan bool
} }
type EndpointMetrics struct { type EndpointMetrics struct {
@@ -44,7 +45,7 @@ func NewMetrics(bufferSize int, flushTimeout time.Duration, file string) *Metric
} }
} }
func (m *Metrics) RequestMetrics(next http.Handler) http.Handler { func (m *Metrics) Use(next http.Handler) http.Handler {
log.Info().Msg("Enabling Request Metrics") log.Info().Msg("Enabling Request Metrics")
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now() start := time.Now()
@@ -71,6 +72,8 @@ func (m *Metrics) Manage() {
m.calculateDuration(rm) m.calculateDuration(rm)
case <-m.ticker.C: case <-m.ticker.C:
m.Flush() m.Flush()
case <-m.stop:
return
} }
} }
} }
@@ -129,3 +132,13 @@ func (m *Metrics) Flush() {
log.Info().Str("file", m.file).Int("count", len(a)).Msg("Completed Metrics flush") log.Info().Str("file", m.file).Int("count", len(a)).Msg("Completed Metrics flush")
} }
func (m *Metrics) Stop() {
log.Info().Msg("Stopping Request Metrics")
for len(m.c) > 0 {
rm := <- m.c
m.calculateDuration(rm)
}
m.Flush()
log.Info().Msg("Stopped Request Metrics")
}

View File

@@ -5,14 +5,52 @@ import (
"slices" "slices"
) )
type Middleware func(http.Handler) http.Handler type Middleware interface {
Use(http.Handler) http.Handler
Manage()
Stop()
}
func Pipeline(funcs ...Middleware) Middleware { type Pipeline struct {
middleware []Middleware
}
func NewPipeline() *Pipeline {
return &Pipeline{}
}
func (p *Pipeline) AddMiddleware(m Middleware) {
p.middleware = append(p.middleware, m)
}
func (p *Pipeline) Use() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
for _, m := range slices.Backward(funcs) { for _, m := range slices.Backward(p.middleware) {
next = m(next) next = m.Use(next)
} }
return next return next
} }
} }
func (p *Pipeline) Stop() {
for _, m := range p.middleware {
m.Stop()
}
}
func (p *Pipeline) Manage() {
for _, m := range p.middleware {
go m.Manage()
}
}
// func Pipeline(funcs ...Middleware) func(http.Handler) http.Handler {
// return func(next http.Handler) http.Handler {
// for _, m := range slices.Backward(funcs) {
// next = m.Use(next)
// }
//
// return next
// }
// }

View File

@@ -18,10 +18,11 @@ type Limiter struct {
bucketRefill int bucketRefill int
rwLock *sync.RWMutex rwLock *sync.RWMutex
rateChannel chan string rateChannel chan string
stop chan struct{}
} }
func NewLimiter(maxRequests int, refills int, refillInterval time.Duration, cleanupInterval time.Duration) Limiter { func NewLimiter(maxRequests int, refills int, refillInterval time.Duration, cleanupInterval time.Duration) *Limiter {
return Limiter{ return &Limiter{
currentBuckets: make(map[string]*atomic.Int64), currentBuckets: make(map[string]*atomic.Int64),
bucketSize: maxRequests, bucketSize: maxRequests,
refillTicker: time.NewTicker(refillInterval), refillTicker: time.NewTicker(refillInterval),
@@ -32,14 +33,15 @@ func NewLimiter(maxRequests int, refills int, refillInterval time.Duration, clea
} }
} }
func (l *Limiter) Start() {
go l.Manage()
}
func (l *Limiter) UpdateCleanupTime(new time.Duration) { func (l *Limiter) UpdateCleanupTime(new time.Duration) {
l.cleanupTicker.Reset(new) l.cleanupTicker.Reset(new)
} }
func (l *Limiter) Stop() {
l.stop <- struct{}{}
log.Info().Msg("Stopped Ratelimits")
}
func (l *Limiter) Manage() { func (l *Limiter) Manage() {
for { for {
select { select {
@@ -77,6 +79,8 @@ func (l *Limiter) Manage() {
l.rwLock.Unlock() l.rwLock.Unlock()
duration := time.Since(start) duration := time.Since(start)
log.Debug().Str("duration", duration.String()).Int("deleted_buckets", deletedBuckets).Msg("Cleaned up Buckets") log.Debug().Str("duration", duration.String()).Int("deleted_buckets", deletedBuckets).Msg("Cleaned up Buckets")
case <- l.stop:
return
} }
} }
} }
@@ -93,7 +97,7 @@ func (l *Limiter) AddIfExists(ip string) bool {
return false return false
} }
func (l *Limiter) RateLimiter(next http.Handler) http.Handler { func (l *Limiter) Use(next http.Handler) http.Handler {
log.Info().Int("bucket_size", l.bucketSize).Int("bucket_refill", l.bucketRefill).Msg("Enabling Ratelimits") log.Info().Int("bucket_size", l.bucketSize).Int("bucket_refill", l.bucketRefill).Msg("Enabling Ratelimits")
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
addr := strings.Split(r.RemoteAddr, ":")[0] addr := strings.Split(r.RemoteAddr, ":")[0]