Add middleware
This commit is contained in:
23
README.md
Normal file
23
README.md
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
# Domain Router
|
||||||
|
|
||||||
|
Reverse Proxy for routing subdomains to different ports on same host machine
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
```csv
|
||||||
|
test.pablu.de;8181
|
||||||
|
manga.pablu.de;8282
|
||||||
|
pablu.de;8080
|
||||||
|
<Url>;<local Port>
|
||||||
|
```
|
||||||
|
|
||||||
|
## Building
|
||||||
|
|
||||||
|
### Build executable
|
||||||
|
```sh
|
||||||
|
make build
|
||||||
|
```
|
||||||
|
|
||||||
|
### Runnging with default config
|
||||||
|
```sh
|
||||||
|
make run
|
||||||
|
```
|
||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
domainrouter "github.com/pablu23/domain-router"
|
domainrouter "github.com/pablu23/domain-router"
|
||||||
|
"github.com/pablu23/domain-router/middleware"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"gopkg.in/natefinch/lumberjack.v2"
|
"gopkg.in/natefinch/lumberjack.v2"
|
||||||
@@ -49,12 +50,17 @@ func main() {
|
|||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
mux.HandleFunc("/", router.Route)
|
mux.HandleFunc("/", router.Route)
|
||||||
|
|
||||||
limiter := domainrouter.NewLimiter(3, 250, 1*time.Minute)
|
limiter := middleware.NewLimiter(10, 250, 30*time.Second, 1*time.Minute)
|
||||||
limiter.Start()
|
limiter.Start()
|
||||||
|
|
||||||
|
pipeline := middleware.Pipeline(
|
||||||
|
limiter.RateLimiter,
|
||||||
|
middleware.RequestLogger,
|
||||||
|
)
|
||||||
|
|
||||||
server := http.Server{
|
server := http.Server{
|
||||||
Addr: fmt.Sprintf(":%d", *portFlag),
|
Addr: fmt.Sprintf(":%d", *portFlag),
|
||||||
Handler: limiter.RateLimiter(domainrouter.RequestLogger(mux)),
|
Handler: pipeline(mux),
|
||||||
}
|
}
|
||||||
|
|
||||||
if *certFlag != "" && *keyFlag != "" {
|
if *certFlag != "" && *keyFlag != "" {
|
||||||
|
|||||||
2
go.mod
2
go.mod
@@ -1,6 +1,6 @@
|
|||||||
module github.com/pablu23/domain-router
|
module github.com/pablu23/domain-router
|
||||||
|
|
||||||
go 1.22.3
|
go 1.23
|
||||||
|
|
||||||
require github.com/rs/zerolog v1.33.0
|
require github.com/rs/zerolog v1.33.0
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package domainrouter
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func RequestLogger(next http.Handler) http.Handler {
|
func RequestLogger(next http.Handler) http.Handler {
|
||||||
|
log.Info().Msg("Enabling Logging")
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
|
||||||
18
middleware/pipeline.go
Normal file
18
middleware/pipeline.go
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"slices"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Middleware func(http.Handler) http.Handler
|
||||||
|
|
||||||
|
func Pipeline(funcs ...Middleware) Middleware {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
for _, m := range slices.Backward(funcs) {
|
||||||
|
next = m(next)
|
||||||
|
}
|
||||||
|
|
||||||
|
return next
|
||||||
|
}
|
||||||
|
}
|
||||||
110
middleware/rate-limit.go
Normal file
110
middleware/rate-limit.go
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Limiter struct {
|
||||||
|
currentBuckets map[string]*atomic.Int64
|
||||||
|
bucketSize int
|
||||||
|
refillTicker *time.Ticker
|
||||||
|
cleanupTicker *time.Ticker
|
||||||
|
bucketRefill int
|
||||||
|
rwLock *sync.RWMutex
|
||||||
|
rateChannel chan string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLimiter(maxRequests int, refills int, refillInterval time.Duration, cleanupInterval time.Duration) Limiter {
|
||||||
|
return Limiter{
|
||||||
|
currentBuckets: make(map[string]*atomic.Int64),
|
||||||
|
bucketSize: maxRequests,
|
||||||
|
refillTicker: time.NewTicker(refillInterval),
|
||||||
|
cleanupTicker: time.NewTicker(cleanupInterval),
|
||||||
|
bucketRefill: refills,
|
||||||
|
rwLock: &sync.RWMutex{},
|
||||||
|
rateChannel: make(chan string),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Limiter) Start() {
|
||||||
|
go l.Manage()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Limiter) UpdateCleanupTime(new time.Duration) {
|
||||||
|
l.cleanupTicker.Reset(new)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Limiter) Manage() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case ip := <-l.rateChannel:
|
||||||
|
l.rwLock.Lock()
|
||||||
|
if counter, ok := l.currentBuckets[ip]; ok {
|
||||||
|
counter.Add(1)
|
||||||
|
} else {
|
||||||
|
counter := &atomic.Int64{}
|
||||||
|
l.currentBuckets[ip] = counter
|
||||||
|
}
|
||||||
|
l.rwLock.Unlock()
|
||||||
|
case <-l.refillTicker.C:
|
||||||
|
l.rwLock.RLock()
|
||||||
|
for ip := range l.currentBuckets {
|
||||||
|
n := l.currentBuckets[ip].Add(int64(-l.bucketRefill))
|
||||||
|
if n < 0 {
|
||||||
|
l.currentBuckets[ip].Store(0)
|
||||||
|
n = 0
|
||||||
|
}
|
||||||
|
log.Trace().Int64("bucket", n).Str("remote", ip).Msg("Updated limit")
|
||||||
|
}
|
||||||
|
l.rwLock.RUnlock()
|
||||||
|
log.Trace().Msg("Refreshed Limits")
|
||||||
|
case <-l.cleanupTicker.C:
|
||||||
|
l.rwLock.Lock()
|
||||||
|
deletedBuckets := 0
|
||||||
|
for ip := range l.currentBuckets {
|
||||||
|
if l.currentBuckets[ip].Load() <= 0 {
|
||||||
|
delete(l.currentBuckets, ip)
|
||||||
|
deletedBuckets += 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
l.rwLock.Unlock()
|
||||||
|
log.Debug().Int("deleted_buckets", deletedBuckets).Msg("Cleaned up Buckets")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Limiter) RateLimiter(next http.Handler) http.Handler {
|
||||||
|
log.Info().Msg("Enabling Ratelimits")
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
addr := strings.Split(r.RemoteAddr, ":")[0]
|
||||||
|
l.rwLock.RLock()
|
||||||
|
count, ok := l.currentBuckets[addr]
|
||||||
|
l.rwLock.RUnlock()
|
||||||
|
if ok && int(count.Load()) >= l.bucketSize {
|
||||||
|
hj, ok := w.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
r.Body.Close()
|
||||||
|
log.Warn().Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Rate limited")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn, _, err := hj.Hijack()
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Could not hijack connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Warn().Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Rate limited")
|
||||||
|
conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
l.rateChannel <- addr
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -1,91 +0,0 @@
|
|||||||
package domainrouter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Limiter struct {
|
|
||||||
current map[string]*atomic.Int64
|
|
||||||
max int
|
|
||||||
ticker *time.Ticker
|
|
||||||
refill int
|
|
||||||
m *sync.RWMutex
|
|
||||||
c chan string
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewLimiter(maxRequests int, refills int, refillInterval time.Duration) Limiter {
|
|
||||||
return Limiter{
|
|
||||||
current: make(map[string]*atomic.Int64),
|
|
||||||
max: maxRequests,
|
|
||||||
ticker: time.NewTicker(refillInterval),
|
|
||||||
refill: refills,
|
|
||||||
m: &sync.RWMutex{},
|
|
||||||
c: make(chan string),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Limiter) Start() {
|
|
||||||
go l.Manage()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Limiter) Manage() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case ip := <-l.c:
|
|
||||||
l.m.Lock()
|
|
||||||
if counter, ok := l.current[ip]; ok {
|
|
||||||
counter.Add(1)
|
|
||||||
} else {
|
|
||||||
counter := &atomic.Int64{}
|
|
||||||
l.current[ip] = counter
|
|
||||||
}
|
|
||||||
l.m.Unlock()
|
|
||||||
case <-l.ticker.C:
|
|
||||||
l.m.RLock()
|
|
||||||
for ip := range l.current {
|
|
||||||
n := l.current[ip].Add(int64(-l.refill))
|
|
||||||
if n < 0 {
|
|
||||||
l.current[ip].Store(0)
|
|
||||||
n = 0
|
|
||||||
}
|
|
||||||
log.Debug().Int64("bucket", n).Str("remote", ip).Msg("Updated limit")
|
|
||||||
}
|
|
||||||
l.m.RUnlock()
|
|
||||||
log.Debug().Msg("Refreshed Limits")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Limiter) RateLimiter(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
addr := strings.Split(r.RemoteAddr, ":")[0]
|
|
||||||
l.m.RLock()
|
|
||||||
count, ok := l.current[addr]
|
|
||||||
l.m.RUnlock()
|
|
||||||
if ok && int(count.Load()) >= l.max {
|
|
||||||
hj, ok := w.(http.Hijacker)
|
|
||||||
if !ok {
|
|
||||||
r.Body.Close()
|
|
||||||
log.Warn().Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Rate limited")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
conn, _, err := hj.Hijack()
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Could not hijack connection")
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Warn().Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Rate limited")
|
|
||||||
conn.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
l.c <- addr
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user