Add graceful stopping to server, and extend middleware and pipeline logic

This commit is contained in:
Pablu23
2025-10-01 14:36:30 +02:00
parent 66f2811fff
commit 572e1177ef
5 changed files with 113 additions and 25 deletions

View File

@@ -1,12 +1,17 @@
package main
import (
"context"
"crypto/tls"
"errors"
"flag"
"fmt"
"io"
"net/http"
"os"
"os/signal"
"sync"
"syscall"
"time"
domainrouter "github.com/pablu23/domain-router"
@@ -37,17 +42,32 @@ func main() {
},
}
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
router := domainrouter.New(config, client)
mux := http.NewServeMux()
mux.HandleFunc("/", router.ServeHTTP)
pipeline := configureMiddleware(config)
pipeline.Manage()
server := http.Server{
Addr: fmt.Sprintf(":%d", config.Server.Port),
Handler: pipeline(mux),
Addr: fmt.Sprintf(":%d", config.Server.Port),
// this is rather bad looking
Handler: pipeline.Use()(mux),
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
<-sigs
log.Info().Msg("Stopping server")
server.Shutdown(context.Background())
pipeline.Stop()
wg.Done()
}()
if config.Server.Ssl.Enabled {
server.TLSConfig = &tls.Config{
GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
@@ -71,16 +91,23 @@ func main() {
}
log.Info().Int("port", config.Server.Port).Str("cert", config.Server.Ssl.CertFile).Str("key", config.Server.Ssl.KeyFile).Msg("Starting server")
err := server.ListenAndServeTLS("", "")
log.Fatal().Err(err).Str("cert", config.Server.Ssl.CertFile).Str("key", config.Server.Ssl.KeyFile).Int("port", config.Server.Port).Msg("Could not start server")
if err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatal().Err(err).Str("cert", config.Server.Ssl.CertFile).Str("key", config.Server.Ssl.KeyFile).Int("port", config.Server.Port).Msg("Could not start server")
}
} else {
log.Info().Int("port", config.Server.Port).Msg("Starting server")
err := server.ListenAndServe()
log.Fatal().Err(err).Int("port", config.Server.Port).Msg("Could not start server")
if err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatal().Err(err).Int("port", config.Server.Port).Msg("Could not start server")
}
}
wg.Wait()
log.Info().Msg("Server shutdown completly, have a nice day")
}
func configureMiddleware(config *domainrouter.Config) middleware.Middleware {
middlewares := make([]middleware.Middleware, 0)
func configureMiddleware(config *domainrouter.Config) *middleware.Pipeline {
pipeline := middleware.NewPipeline()
if config.RateLimit.Enabled {
refillTicker, err := time.ParseDuration(config.RateLimit.RefillTicker)
@@ -93,19 +120,16 @@ func configureMiddleware(config *domainrouter.Config) middleware.Middleware {
log.Fatal().Err(err).Str("cleanup", config.RateLimit.CleanupTicker).Msg("Could not parse cleanup Ticker")
}
limiter := middleware.NewLimiter(config.RateLimit.BucketSize, config.RateLimit.BucketRefill, refillTicker, cleanupTicker)
limiter.Start()
middlewares = append(middlewares, limiter.RateLimiter)
pipeline.AddMiddleware(limiter)
}
if config.Logging.Requests {
middlewares = append(middlewares, middleware.RequestLogger)
pipeline.AddMiddleware(&middleware.RequestLogger{})
}
metrics := middleware.NewMetrics(512, 1*time.Minute, "tmp_metrics.json")
go metrics.Manage()
middlewares = append(middlewares, metrics.RequestMetrics)
pipeline.AddMiddleware(metrics)
pipeline := middleware.Pipeline(middlewares...)
return pipeline
}

View File

@@ -9,7 +9,16 @@ import (
"github.com/urfave/negroni"
)
func RequestLogger(next http.Handler) http.Handler {
type RequestLogger struct{}
func (_ *RequestLogger) Stop() {
log.Info().Msg("Stopped Logging")
}
func (_ *RequestLogger) Manage() {
}
func (_ *RequestLogger) Use(next http.Handler) http.Handler {
log.Info().Msg("Enabling Logging")
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()

View File

@@ -17,6 +17,7 @@ type Metrics struct {
endpointMetrics []EndpointMetrics
ticker *time.Ticker
file string
stop chan bool
}
type EndpointMetrics struct {
@@ -44,7 +45,7 @@ func NewMetrics(bufferSize int, flushTimeout time.Duration, file string) *Metric
}
}
func (m *Metrics) RequestMetrics(next http.Handler) http.Handler {
func (m *Metrics) Use(next http.Handler) http.Handler {
log.Info().Msg("Enabling Request Metrics")
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
@@ -71,6 +72,8 @@ func (m *Metrics) Manage() {
m.calculateDuration(rm)
case <-m.ticker.C:
m.Flush()
case <-m.stop:
return
}
}
}
@@ -129,3 +132,13 @@ func (m *Metrics) Flush() {
log.Info().Str("file", m.file).Int("count", len(a)).Msg("Completed Metrics flush")
}
func (m *Metrics) Stop() {
log.Info().Msg("Stopping Request Metrics")
for len(m.c) > 0 {
rm := <- m.c
m.calculateDuration(rm)
}
m.Flush()
log.Info().Msg("Stopped Request Metrics")
}

View File

@@ -5,14 +5,52 @@ import (
"slices"
)
type Middleware func(http.Handler) http.Handler
type Middleware interface {
Use(http.Handler) http.Handler
Manage()
Stop()
}
func Pipeline(funcs ...Middleware) Middleware {
type Pipeline struct {
middleware []Middleware
}
func NewPipeline() *Pipeline {
return &Pipeline{}
}
func (p *Pipeline) AddMiddleware(m Middleware) {
p.middleware = append(p.middleware, m)
}
func (p *Pipeline) Use() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
for _, m := range slices.Backward(funcs) {
next = m(next)
for _, m := range slices.Backward(p.middleware) {
next = m.Use(next)
}
return next
}
}
func (p *Pipeline) Stop() {
for _, m := range p.middleware {
m.Stop()
}
}
func (p *Pipeline) Manage() {
for _, m := range p.middleware {
go m.Manage()
}
}
// func Pipeline(funcs ...Middleware) func(http.Handler) http.Handler {
// return func(next http.Handler) http.Handler {
// for _, m := range slices.Backward(funcs) {
// next = m.Use(next)
// }
//
// return next
// }
// }

View File

@@ -18,10 +18,11 @@ type Limiter struct {
bucketRefill int
rwLock *sync.RWMutex
rateChannel chan string
stop chan struct{}
}
func NewLimiter(maxRequests int, refills int, refillInterval time.Duration, cleanupInterval time.Duration) Limiter {
return Limiter{
func NewLimiter(maxRequests int, refills int, refillInterval time.Duration, cleanupInterval time.Duration) *Limiter {
return &Limiter{
currentBuckets: make(map[string]*atomic.Int64),
bucketSize: maxRequests,
refillTicker: time.NewTicker(refillInterval),
@@ -32,14 +33,15 @@ func NewLimiter(maxRequests int, refills int, refillInterval time.Duration, clea
}
}
func (l *Limiter) Start() {
go l.Manage()
}
func (l *Limiter) UpdateCleanupTime(new time.Duration) {
l.cleanupTicker.Reset(new)
}
func (l *Limiter) Stop() {
l.stop <- struct{}{}
log.Info().Msg("Stopped Ratelimits")
}
func (l *Limiter) Manage() {
for {
select {
@@ -77,6 +79,8 @@ func (l *Limiter) Manage() {
l.rwLock.Unlock()
duration := time.Since(start)
log.Debug().Str("duration", duration.String()).Int("deleted_buckets", deletedBuckets).Msg("Cleaned up Buckets")
case <- l.stop:
return
}
}
}
@@ -93,7 +97,7 @@ func (l *Limiter) AddIfExists(ip string) bool {
return false
}
func (l *Limiter) RateLimiter(next http.Handler) http.Handler {
func (l *Limiter) Use(next http.Handler) http.Handler {
log.Info().Int("bucket_size", l.bucketSize).Int("bucket_refill", l.bucketRefill).Msg("Enabling Ratelimits")
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
addr := strings.Split(r.RemoteAddr, ":")[0]