diff --git a/README.md b/README.md new file mode 100644 index 0000000..6871b4f --- /dev/null +++ b/README.md @@ -0,0 +1,23 @@ +# Domain Router + +Reverse Proxy for routing subdomains to different ports on same host machine + +## Configuration +```csv +test.pablu.de;8181 +manga.pablu.de;8282 +pablu.de;8080 +; +``` + +## Building + +### Build executable +```sh +make build +``` + +### Runnging with default config +```sh +make run +``` diff --git a/cmd/domain-router/main.go b/cmd/domain-router/main.go index a99acc8..1b4341b 100644 --- a/cmd/domain-router/main.go +++ b/cmd/domain-router/main.go @@ -14,6 +14,7 @@ import ( "time" domainrouter "github.com/pablu23/domain-router" + "github.com/pablu23/domain-router/middleware" "github.com/rs/zerolog" "github.com/rs/zerolog/log" "gopkg.in/natefinch/lumberjack.v2" @@ -49,12 +50,17 @@ func main() { mux := http.NewServeMux() mux.HandleFunc("/", router.Route) - limiter := domainrouter.NewLimiter(3, 250, 1*time.Minute) + limiter := middleware.NewLimiter(10, 250, 30*time.Second, 1*time.Minute) limiter.Start() + pipeline := middleware.Pipeline( + limiter.RateLimiter, + middleware.RequestLogger, + ) + server := http.Server{ Addr: fmt.Sprintf(":%d", *portFlag), - Handler: limiter.RateLimiter(domainrouter.RequestLogger(mux)), + Handler: pipeline(mux), } if *certFlag != "" && *keyFlag != "" { diff --git a/go.mod b/go.mod index a4c8276..d351408 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/pablu23/domain-router -go 1.22.3 +go 1.23 require github.com/rs/zerolog v1.33.0 diff --git a/logging-middleware.go b/middleware/logging.go similarity index 93% rename from logging-middleware.go rename to middleware/logging.go index a09da3b..1d09918 100644 --- a/logging-middleware.go +++ b/middleware/logging.go @@ -1,4 +1,4 @@ -package domainrouter +package middleware import ( "net/http" @@ -9,6 +9,7 @@ import ( ) func RequestLogger(next http.Handler) http.Handler { + log.Info().Msg("Enabling Logging") return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() diff --git a/middleware/pipeline.go b/middleware/pipeline.go new file mode 100644 index 0000000..f95a2d8 --- /dev/null +++ b/middleware/pipeline.go @@ -0,0 +1,18 @@ +package middleware + +import ( + "net/http" + "slices" +) + +type Middleware func(http.Handler) http.Handler + +func Pipeline(funcs ...Middleware) Middleware { + return func(next http.Handler) http.Handler { + for _, m := range slices.Backward(funcs) { + next = m(next) + } + + return next + } +} diff --git a/middleware/rate-limit.go b/middleware/rate-limit.go new file mode 100644 index 0000000..b4a197c --- /dev/null +++ b/middleware/rate-limit.go @@ -0,0 +1,110 @@ +package middleware + +import ( + "net/http" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/rs/zerolog/log" +) + +type Limiter struct { + currentBuckets map[string]*atomic.Int64 + bucketSize int + refillTicker *time.Ticker + cleanupTicker *time.Ticker + bucketRefill int + rwLock *sync.RWMutex + rateChannel chan string +} + +func NewLimiter(maxRequests int, refills int, refillInterval time.Duration, cleanupInterval time.Duration) Limiter { + return Limiter{ + currentBuckets: make(map[string]*atomic.Int64), + bucketSize: maxRequests, + refillTicker: time.NewTicker(refillInterval), + cleanupTicker: time.NewTicker(cleanupInterval), + bucketRefill: refills, + rwLock: &sync.RWMutex{}, + rateChannel: make(chan string), + } +} + +func (l *Limiter) Start() { + go l.Manage() + return +} + +func (l *Limiter) UpdateCleanupTime(new time.Duration) { + l.cleanupTicker.Reset(new) +} + +func (l *Limiter) Manage() { + for { + select { + case ip := <-l.rateChannel: + l.rwLock.Lock() + if counter, ok := l.currentBuckets[ip]; ok { + counter.Add(1) + } else { + counter := &atomic.Int64{} + l.currentBuckets[ip] = counter + } + l.rwLock.Unlock() + case <-l.refillTicker.C: + l.rwLock.RLock() + for ip := range l.currentBuckets { + n := l.currentBuckets[ip].Add(int64(-l.bucketRefill)) + if n < 0 { + l.currentBuckets[ip].Store(0) + n = 0 + } + log.Trace().Int64("bucket", n).Str("remote", ip).Msg("Updated limit") + } + l.rwLock.RUnlock() + log.Trace().Msg("Refreshed Limits") + case <-l.cleanupTicker.C: + l.rwLock.Lock() + deletedBuckets := 0 + for ip := range l.currentBuckets { + if l.currentBuckets[ip].Load() <= 0 { + delete(l.currentBuckets, ip) + deletedBuckets += 1 + } + } + l.rwLock.Unlock() + log.Debug().Int("deleted_buckets", deletedBuckets).Msg("Cleaned up Buckets") + } + } + +} + +func (l *Limiter) RateLimiter(next http.Handler) http.Handler { + log.Info().Msg("Enabling Ratelimits") + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + addr := strings.Split(r.RemoteAddr, ":")[0] + l.rwLock.RLock() + count, ok := l.currentBuckets[addr] + l.rwLock.RUnlock() + if ok && int(count.Load()) >= l.bucketSize { + hj, ok := w.(http.Hijacker) + if !ok { + r.Body.Close() + log.Warn().Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Rate limited") + return + } + conn, _, err := hj.Hijack() + if err != nil { + log.Error().Err(err).Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Could not hijack connection") + } + + log.Warn().Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Rate limited") + conn.Close() + return + } + l.rateChannel <- addr + next.ServeHTTP(w, r) + }) +} diff --git a/rate-limit.go b/rate-limit.go deleted file mode 100644 index 337d7c0..0000000 --- a/rate-limit.go +++ /dev/null @@ -1,91 +0,0 @@ -package domainrouter - -import ( - "net/http" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/rs/zerolog/log" -) - -type Limiter struct { - current map[string]*atomic.Int64 - max int - ticker *time.Ticker - refill int - m *sync.RWMutex - c chan string -} - -func NewLimiter(maxRequests int, refills int, refillInterval time.Duration) Limiter { - return Limiter{ - current: make(map[string]*atomic.Int64), - max: maxRequests, - ticker: time.NewTicker(refillInterval), - refill: refills, - m: &sync.RWMutex{}, - c: make(chan string), - } -} - -func (l *Limiter) Start() { - go l.Manage() - return -} - -func (l *Limiter) Manage() { - for { - select { - case ip := <-l.c: - l.m.Lock() - if counter, ok := l.current[ip]; ok { - counter.Add(1) - } else { - counter := &atomic.Int64{} - l.current[ip] = counter - } - l.m.Unlock() - case <-l.ticker.C: - l.m.RLock() - for ip := range l.current { - n := l.current[ip].Add(int64(-l.refill)) - if n < 0 { - l.current[ip].Store(0) - n = 0 - } - log.Debug().Int64("bucket", n).Str("remote", ip).Msg("Updated limit") - } - l.m.RUnlock() - log.Debug().Msg("Refreshed Limits") - } - } -} - -func (l *Limiter) RateLimiter(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - addr := strings.Split(r.RemoteAddr, ":")[0] - l.m.RLock() - count, ok := l.current[addr] - l.m.RUnlock() - if ok && int(count.Load()) >= l.max { - hj, ok := w.(http.Hijacker) - if !ok { - r.Body.Close() - log.Warn().Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Rate limited") - return - } - conn, _, err := hj.Hijack() - if err != nil { - log.Error().Err(err).Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Could not hijack connection") - } - - log.Warn().Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Rate limited") - conn.Close() - return - } - l.c <- addr - next.ServeHTTP(w, r) - }) -}