Add middleware
This commit is contained in:
23
README.md
Normal file
23
README.md
Normal file
@@ -0,0 +1,23 @@
|
||||
# Domain Router
|
||||
|
||||
Reverse Proxy for routing subdomains to different ports on same host machine
|
||||
|
||||
## Configuration
|
||||
```csv
|
||||
test.pablu.de;8181
|
||||
manga.pablu.de;8282
|
||||
pablu.de;8080
|
||||
<Url>;<local Port>
|
||||
```
|
||||
|
||||
## Building
|
||||
|
||||
### Build executable
|
||||
```sh
|
||||
make build
|
||||
```
|
||||
|
||||
### Runnging with default config
|
||||
```sh
|
||||
make run
|
||||
```
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"time"
|
||||
|
||||
domainrouter "github.com/pablu23/domain-router"
|
||||
"github.com/pablu23/domain-router/middleware"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
@@ -49,12 +50,17 @@ func main() {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", router.Route)
|
||||
|
||||
limiter := domainrouter.NewLimiter(3, 250, 1*time.Minute)
|
||||
limiter := middleware.NewLimiter(10, 250, 30*time.Second, 1*time.Minute)
|
||||
limiter.Start()
|
||||
|
||||
pipeline := middleware.Pipeline(
|
||||
limiter.RateLimiter,
|
||||
middleware.RequestLogger,
|
||||
)
|
||||
|
||||
server := http.Server{
|
||||
Addr: fmt.Sprintf(":%d", *portFlag),
|
||||
Handler: limiter.RateLimiter(domainrouter.RequestLogger(mux)),
|
||||
Handler: pipeline(mux),
|
||||
}
|
||||
|
||||
if *certFlag != "" && *keyFlag != "" {
|
||||
|
||||
2
go.mod
2
go.mod
@@ -1,6 +1,6 @@
|
||||
module github.com/pablu23/domain-router
|
||||
|
||||
go 1.22.3
|
||||
go 1.23
|
||||
|
||||
require github.com/rs/zerolog v1.33.0
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package domainrouter
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
)
|
||||
|
||||
func RequestLogger(next http.Handler) http.Handler {
|
||||
log.Info().Msg("Enabling Logging")
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
18
middleware/pipeline.go
Normal file
18
middleware/pipeline.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"slices"
|
||||
)
|
||||
|
||||
type Middleware func(http.Handler) http.Handler
|
||||
|
||||
func Pipeline(funcs ...Middleware) Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
for _, m := range slices.Backward(funcs) {
|
||||
next = m(next)
|
||||
}
|
||||
|
||||
return next
|
||||
}
|
||||
}
|
||||
110
middleware/rate-limit.go
Normal file
110
middleware/rate-limit.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type Limiter struct {
|
||||
currentBuckets map[string]*atomic.Int64
|
||||
bucketSize int
|
||||
refillTicker *time.Ticker
|
||||
cleanupTicker *time.Ticker
|
||||
bucketRefill int
|
||||
rwLock *sync.RWMutex
|
||||
rateChannel chan string
|
||||
}
|
||||
|
||||
func NewLimiter(maxRequests int, refills int, refillInterval time.Duration, cleanupInterval time.Duration) Limiter {
|
||||
return Limiter{
|
||||
currentBuckets: make(map[string]*atomic.Int64),
|
||||
bucketSize: maxRequests,
|
||||
refillTicker: time.NewTicker(refillInterval),
|
||||
cleanupTicker: time.NewTicker(cleanupInterval),
|
||||
bucketRefill: refills,
|
||||
rwLock: &sync.RWMutex{},
|
||||
rateChannel: make(chan string),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Limiter) Start() {
|
||||
go l.Manage()
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Limiter) UpdateCleanupTime(new time.Duration) {
|
||||
l.cleanupTicker.Reset(new)
|
||||
}
|
||||
|
||||
func (l *Limiter) Manage() {
|
||||
for {
|
||||
select {
|
||||
case ip := <-l.rateChannel:
|
||||
l.rwLock.Lock()
|
||||
if counter, ok := l.currentBuckets[ip]; ok {
|
||||
counter.Add(1)
|
||||
} else {
|
||||
counter := &atomic.Int64{}
|
||||
l.currentBuckets[ip] = counter
|
||||
}
|
||||
l.rwLock.Unlock()
|
||||
case <-l.refillTicker.C:
|
||||
l.rwLock.RLock()
|
||||
for ip := range l.currentBuckets {
|
||||
n := l.currentBuckets[ip].Add(int64(-l.bucketRefill))
|
||||
if n < 0 {
|
||||
l.currentBuckets[ip].Store(0)
|
||||
n = 0
|
||||
}
|
||||
log.Trace().Int64("bucket", n).Str("remote", ip).Msg("Updated limit")
|
||||
}
|
||||
l.rwLock.RUnlock()
|
||||
log.Trace().Msg("Refreshed Limits")
|
||||
case <-l.cleanupTicker.C:
|
||||
l.rwLock.Lock()
|
||||
deletedBuckets := 0
|
||||
for ip := range l.currentBuckets {
|
||||
if l.currentBuckets[ip].Load() <= 0 {
|
||||
delete(l.currentBuckets, ip)
|
||||
deletedBuckets += 1
|
||||
}
|
||||
}
|
||||
l.rwLock.Unlock()
|
||||
log.Debug().Int("deleted_buckets", deletedBuckets).Msg("Cleaned up Buckets")
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (l *Limiter) RateLimiter(next http.Handler) http.Handler {
|
||||
log.Info().Msg("Enabling Ratelimits")
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
addr := strings.Split(r.RemoteAddr, ":")[0]
|
||||
l.rwLock.RLock()
|
||||
count, ok := l.currentBuckets[addr]
|
||||
l.rwLock.RUnlock()
|
||||
if ok && int(count.Load()) >= l.bucketSize {
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
r.Body.Close()
|
||||
log.Warn().Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Rate limited")
|
||||
return
|
||||
}
|
||||
conn, _, err := hj.Hijack()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Could not hijack connection")
|
||||
}
|
||||
|
||||
log.Warn().Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Rate limited")
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
l.rateChannel <- addr
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
@@ -1,91 +0,0 @@
|
||||
package domainrouter
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type Limiter struct {
|
||||
current map[string]*atomic.Int64
|
||||
max int
|
||||
ticker *time.Ticker
|
||||
refill int
|
||||
m *sync.RWMutex
|
||||
c chan string
|
||||
}
|
||||
|
||||
func NewLimiter(maxRequests int, refills int, refillInterval time.Duration) Limiter {
|
||||
return Limiter{
|
||||
current: make(map[string]*atomic.Int64),
|
||||
max: maxRequests,
|
||||
ticker: time.NewTicker(refillInterval),
|
||||
refill: refills,
|
||||
m: &sync.RWMutex{},
|
||||
c: make(chan string),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Limiter) Start() {
|
||||
go l.Manage()
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Limiter) Manage() {
|
||||
for {
|
||||
select {
|
||||
case ip := <-l.c:
|
||||
l.m.Lock()
|
||||
if counter, ok := l.current[ip]; ok {
|
||||
counter.Add(1)
|
||||
} else {
|
||||
counter := &atomic.Int64{}
|
||||
l.current[ip] = counter
|
||||
}
|
||||
l.m.Unlock()
|
||||
case <-l.ticker.C:
|
||||
l.m.RLock()
|
||||
for ip := range l.current {
|
||||
n := l.current[ip].Add(int64(-l.refill))
|
||||
if n < 0 {
|
||||
l.current[ip].Store(0)
|
||||
n = 0
|
||||
}
|
||||
log.Debug().Int64("bucket", n).Str("remote", ip).Msg("Updated limit")
|
||||
}
|
||||
l.m.RUnlock()
|
||||
log.Debug().Msg("Refreshed Limits")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Limiter) RateLimiter(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
addr := strings.Split(r.RemoteAddr, ":")[0]
|
||||
l.m.RLock()
|
||||
count, ok := l.current[addr]
|
||||
l.m.RUnlock()
|
||||
if ok && int(count.Load()) >= l.max {
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
r.Body.Close()
|
||||
log.Warn().Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Rate limited")
|
||||
return
|
||||
}
|
||||
conn, _, err := hj.Hijack()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Could not hijack connection")
|
||||
}
|
||||
|
||||
log.Warn().Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Rate limited")
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
l.c <- addr
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user