Add graceful stopping to server, and extend middleware and pipeline logic
This commit is contained in:
@@ -1,12 +1,17 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
domainrouter "github.com/pablu23/domain-router"
|
domainrouter "github.com/pablu23/domain-router"
|
||||||
@@ -37,17 +42,32 @@ func main() {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sigs := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
router := domainrouter.New(config, client)
|
router := domainrouter.New(config, client)
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
mux.HandleFunc("/", router.ServeHTTP)
|
mux.HandleFunc("/", router.ServeHTTP)
|
||||||
|
|
||||||
pipeline := configureMiddleware(config)
|
pipeline := configureMiddleware(config)
|
||||||
|
|
||||||
|
pipeline.Manage()
|
||||||
server := http.Server{
|
server := http.Server{
|
||||||
Addr: fmt.Sprintf(":%d", config.Server.Port),
|
Addr: fmt.Sprintf(":%d", config.Server.Port),
|
||||||
Handler: pipeline(mux),
|
// this is rather bad looking
|
||||||
|
Handler: pipeline.Use()(mux),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
<-sigs
|
||||||
|
log.Info().Msg("Stopping server")
|
||||||
|
server.Shutdown(context.Background())
|
||||||
|
pipeline.Stop()
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
if config.Server.Ssl.Enabled {
|
if config.Server.Ssl.Enabled {
|
||||||
server.TLSConfig = &tls.Config{
|
server.TLSConfig = &tls.Config{
|
||||||
GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
@@ -71,16 +91,23 @@ func main() {
|
|||||||
}
|
}
|
||||||
log.Info().Int("port", config.Server.Port).Str("cert", config.Server.Ssl.CertFile).Str("key", config.Server.Ssl.KeyFile).Msg("Starting server")
|
log.Info().Int("port", config.Server.Port).Str("cert", config.Server.Ssl.CertFile).Str("key", config.Server.Ssl.KeyFile).Msg("Starting server")
|
||||||
err := server.ListenAndServeTLS("", "")
|
err := server.ListenAndServeTLS("", "")
|
||||||
|
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
log.Fatal().Err(err).Str("cert", config.Server.Ssl.CertFile).Str("key", config.Server.Ssl.KeyFile).Int("port", config.Server.Port).Msg("Could not start server")
|
log.Fatal().Err(err).Str("cert", config.Server.Ssl.CertFile).Str("key", config.Server.Ssl.KeyFile).Int("port", config.Server.Port).Msg("Could not start server")
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
log.Info().Int("port", config.Server.Port).Msg("Starting server")
|
log.Info().Int("port", config.Server.Port).Msg("Starting server")
|
||||||
err := server.ListenAndServe()
|
err := server.ListenAndServe()
|
||||||
|
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
log.Fatal().Err(err).Int("port", config.Server.Port).Msg("Could not start server")
|
log.Fatal().Err(err).Int("port", config.Server.Port).Msg("Could not start server")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
log.Info().Msg("Server shutdown completly, have a nice day")
|
||||||
}
|
}
|
||||||
|
|
||||||
func configureMiddleware(config *domainrouter.Config) middleware.Middleware {
|
func configureMiddleware(config *domainrouter.Config) *middleware.Pipeline {
|
||||||
middlewares := make([]middleware.Middleware, 0)
|
pipeline := middleware.NewPipeline()
|
||||||
|
|
||||||
if config.RateLimit.Enabled {
|
if config.RateLimit.Enabled {
|
||||||
refillTicker, err := time.ParseDuration(config.RateLimit.RefillTicker)
|
refillTicker, err := time.ParseDuration(config.RateLimit.RefillTicker)
|
||||||
@@ -93,19 +120,16 @@ func configureMiddleware(config *domainrouter.Config) middleware.Middleware {
|
|||||||
log.Fatal().Err(err).Str("cleanup", config.RateLimit.CleanupTicker).Msg("Could not parse cleanup Ticker")
|
log.Fatal().Err(err).Str("cleanup", config.RateLimit.CleanupTicker).Msg("Could not parse cleanup Ticker")
|
||||||
}
|
}
|
||||||
limiter := middleware.NewLimiter(config.RateLimit.BucketSize, config.RateLimit.BucketRefill, refillTicker, cleanupTicker)
|
limiter := middleware.NewLimiter(config.RateLimit.BucketSize, config.RateLimit.BucketRefill, refillTicker, cleanupTicker)
|
||||||
limiter.Start()
|
pipeline.AddMiddleware(limiter)
|
||||||
middlewares = append(middlewares, limiter.RateLimiter)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.Logging.Requests {
|
if config.Logging.Requests {
|
||||||
middlewares = append(middlewares, middleware.RequestLogger)
|
pipeline.AddMiddleware(&middleware.RequestLogger{})
|
||||||
}
|
}
|
||||||
|
|
||||||
metrics := middleware.NewMetrics(512, 1*time.Minute, "tmp_metrics.json")
|
metrics := middleware.NewMetrics(512, 1*time.Minute, "tmp_metrics.json")
|
||||||
go metrics.Manage()
|
pipeline.AddMiddleware(metrics)
|
||||||
middlewares = append(middlewares, metrics.RequestMetrics)
|
|
||||||
|
|
||||||
pipeline := middleware.Pipeline(middlewares...)
|
|
||||||
return pipeline
|
return pipeline
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,16 @@ import (
|
|||||||
"github.com/urfave/negroni"
|
"github.com/urfave/negroni"
|
||||||
)
|
)
|
||||||
|
|
||||||
func RequestLogger(next http.Handler) http.Handler {
|
type RequestLogger struct{}
|
||||||
|
|
||||||
|
func (_ *RequestLogger) Stop() {
|
||||||
|
log.Info().Msg("Stopped Logging")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_ *RequestLogger) Manage() {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_ *RequestLogger) Use(next http.Handler) http.Handler {
|
||||||
log.Info().Msg("Enabling Logging")
|
log.Info().Msg("Enabling Logging")
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ type Metrics struct {
|
|||||||
endpointMetrics []EndpointMetrics
|
endpointMetrics []EndpointMetrics
|
||||||
ticker *time.Ticker
|
ticker *time.Ticker
|
||||||
file string
|
file string
|
||||||
|
stop chan bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type EndpointMetrics struct {
|
type EndpointMetrics struct {
|
||||||
@@ -44,7 +45,7 @@ func NewMetrics(bufferSize int, flushTimeout time.Duration, file string) *Metric
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Metrics) RequestMetrics(next http.Handler) http.Handler {
|
func (m *Metrics) Use(next http.Handler) http.Handler {
|
||||||
log.Info().Msg("Enabling Request Metrics")
|
log.Info().Msg("Enabling Request Metrics")
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
@@ -71,6 +72,8 @@ func (m *Metrics) Manage() {
|
|||||||
m.calculateDuration(rm)
|
m.calculateDuration(rm)
|
||||||
case <-m.ticker.C:
|
case <-m.ticker.C:
|
||||||
m.Flush()
|
m.Flush()
|
||||||
|
case <-m.stop:
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -129,3 +132,13 @@ func (m *Metrics) Flush() {
|
|||||||
|
|
||||||
log.Info().Str("file", m.file).Int("count", len(a)).Msg("Completed Metrics flush")
|
log.Info().Str("file", m.file).Int("count", len(a)).Msg("Completed Metrics flush")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Metrics) Stop() {
|
||||||
|
log.Info().Msg("Stopping Request Metrics")
|
||||||
|
for len(m.c) > 0 {
|
||||||
|
rm := <- m.c
|
||||||
|
m.calculateDuration(rm)
|
||||||
|
}
|
||||||
|
m.Flush()
|
||||||
|
log.Info().Msg("Stopped Request Metrics")
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,14 +5,52 @@ import (
|
|||||||
"slices"
|
"slices"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Middleware func(http.Handler) http.Handler
|
type Middleware interface {
|
||||||
|
Use(http.Handler) http.Handler
|
||||||
|
Manage()
|
||||||
|
Stop()
|
||||||
|
}
|
||||||
|
|
||||||
func Pipeline(funcs ...Middleware) Middleware {
|
type Pipeline struct {
|
||||||
|
middleware []Middleware
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPipeline() *Pipeline {
|
||||||
|
return &Pipeline{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pipeline) AddMiddleware(m Middleware) {
|
||||||
|
p.middleware = append(p.middleware, m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pipeline) Use() func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
for _, m := range slices.Backward(funcs) {
|
for _, m := range slices.Backward(p.middleware) {
|
||||||
next = m(next)
|
next = m.Use(next)
|
||||||
}
|
}
|
||||||
|
|
||||||
return next
|
return next
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Pipeline) Stop() {
|
||||||
|
for _, m := range p.middleware {
|
||||||
|
m.Stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pipeline) Manage() {
|
||||||
|
for _, m := range p.middleware {
|
||||||
|
go m.Manage()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// func Pipeline(funcs ...Middleware) func(http.Handler) http.Handler {
|
||||||
|
// return func(next http.Handler) http.Handler {
|
||||||
|
// for _, m := range slices.Backward(funcs) {
|
||||||
|
// next = m.Use(next)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// return next
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|||||||
@@ -18,10 +18,11 @@ type Limiter struct {
|
|||||||
bucketRefill int
|
bucketRefill int
|
||||||
rwLock *sync.RWMutex
|
rwLock *sync.RWMutex
|
||||||
rateChannel chan string
|
rateChannel chan string
|
||||||
|
stop chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLimiter(maxRequests int, refills int, refillInterval time.Duration, cleanupInterval time.Duration) Limiter {
|
func NewLimiter(maxRequests int, refills int, refillInterval time.Duration, cleanupInterval time.Duration) *Limiter {
|
||||||
return Limiter{
|
return &Limiter{
|
||||||
currentBuckets: make(map[string]*atomic.Int64),
|
currentBuckets: make(map[string]*atomic.Int64),
|
||||||
bucketSize: maxRequests,
|
bucketSize: maxRequests,
|
||||||
refillTicker: time.NewTicker(refillInterval),
|
refillTicker: time.NewTicker(refillInterval),
|
||||||
@@ -32,14 +33,15 @@ func NewLimiter(maxRequests int, refills int, refillInterval time.Duration, clea
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Limiter) Start() {
|
|
||||||
go l.Manage()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Limiter) UpdateCleanupTime(new time.Duration) {
|
func (l *Limiter) UpdateCleanupTime(new time.Duration) {
|
||||||
l.cleanupTicker.Reset(new)
|
l.cleanupTicker.Reset(new)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l *Limiter) Stop() {
|
||||||
|
l.stop <- struct{}{}
|
||||||
|
log.Info().Msg("Stopped Ratelimits")
|
||||||
|
}
|
||||||
|
|
||||||
func (l *Limiter) Manage() {
|
func (l *Limiter) Manage() {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@@ -77,6 +79,8 @@ func (l *Limiter) Manage() {
|
|||||||
l.rwLock.Unlock()
|
l.rwLock.Unlock()
|
||||||
duration := time.Since(start)
|
duration := time.Since(start)
|
||||||
log.Debug().Str("duration", duration.String()).Int("deleted_buckets", deletedBuckets).Msg("Cleaned up Buckets")
|
log.Debug().Str("duration", duration.String()).Int("deleted_buckets", deletedBuckets).Msg("Cleaned up Buckets")
|
||||||
|
case <- l.stop:
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -93,7 +97,7 @@ func (l *Limiter) AddIfExists(ip string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Limiter) RateLimiter(next http.Handler) http.Handler {
|
func (l *Limiter) Use(next http.Handler) http.Handler {
|
||||||
log.Info().Int("bucket_size", l.bucketSize).Int("bucket_refill", l.bucketRefill).Msg("Enabling Ratelimits")
|
log.Info().Int("bucket_size", l.bucketSize).Int("bucket_refill", l.bucketRefill).Msg("Enabling Ratelimits")
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
addr := strings.Split(r.RemoteAddr, ":")[0]
|
addr := strings.Split(r.RemoteAddr, ":")[0]
|
||||||
|
|||||||
Reference in New Issue
Block a user