diff --git a/.gitignore b/.gitignore index fa92c96..86fad05 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,3 @@ *.crt *.key -domain-router +bin/ diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..015dda0 --- /dev/null +++ b/Makefile @@ -0,0 +1,5 @@ +run: build + sudo ./bin/domain-router --pretty --log-level debug + +build: + go build -o bin/domain-router cmd/domain-router/main.go diff --git a/cmd/domain-router/main.go b/cmd/domain-router/main.go new file mode 100644 index 0000000..a99acc8 --- /dev/null +++ b/cmd/domain-router/main.go @@ -0,0 +1,128 @@ +package main + +import ( + "bufio" + "crypto/tls" + "errors" + "flag" + "fmt" + "io" + "net/http" + "os" + "strconv" + "strings" + "time" + + domainrouter "github.com/pablu23/domain-router" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "gopkg.in/natefinch/lumberjack.v2" +) + +var ( + configFileFlag = flag.String("config", "domains.conf", "Path to Domain config file") + certFlag = flag.String("cert", "", "Path to cert file") + keyFlag = flag.String("key", "", "Path to key file") + portFlag = flag.Int("port", 80, "Port") + prettyLogsFlag = flag.Bool("pretty", false, "Pretty print? Default is json") + logPathFlag = flag.String("log", "", "Path to logfile, default is stderr") + logLevelFlag = flag.String("log-level", "info", "Log Level") +) + +func main() { + flag.Parse() + + setupLogging() + + domains, err := loadConfig(*configFileFlag) + if err != nil { + log.Fatal().Err(err).Str("path", *configFileFlag).Msg("Could not load Config") + } + + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + router := domainrouter.New(domains, client) + mux := http.NewServeMux() + mux.HandleFunc("/", router.Route) + + limiter := domainrouter.NewLimiter(3, 250, 1*time.Minute) + limiter.Start() + + server := http.Server{ + Addr: fmt.Sprintf(":%d", *portFlag), + Handler: limiter.RateLimiter(domainrouter.RequestLogger(mux)), + } + + if *certFlag != "" && *keyFlag != "" { + server.TLSConfig = &tls.Config{ + GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { + cert, err := tls.LoadX509KeyPair(*certFlag, *keyFlag) + if err != nil { + return nil, err + } + return &cert, err + }, + } + + log.Info().Int("port", *portFlag).Str("cert", *certFlag).Str("key", *keyFlag).Msg("Starting server") + err := server.ListenAndServeTLS("", "") + log.Fatal().Err(err).Str("cert", *certFlag).Str("key", *keyFlag).Int("port", *portFlag).Msg("Could not start server") + } else { + log.Info().Int("port", *portFlag).Msg("Starting server") + err := server.ListenAndServe() + log.Fatal().Err(err).Int("port", *portFlag).Msg("Could not start server") + } +} + +func setupLogging() { + logLevel, err := zerolog.ParseLevel(*logLevelFlag) + if err != nil { + log.Fatal().Err(err).Str("level", *logLevelFlag).Msg("Could not parse string to level") + } + + zerolog.SetGlobalLevel(logLevel) + if *prettyLogsFlag { + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + } + + if *logPathFlag != "" { + var console io.Writer = os.Stderr + if *prettyLogsFlag { + console = zerolog.ConsoleWriter{Out: os.Stderr} + } + log.Logger = log.Output(zerolog.MultiLevelWriter(console, &lumberjack.Logger{ + Filename: *logPathFlag, + MaxAge: 14, + MaxBackups: 10, + })) + } +} + +func loadConfig(path string) (map[string]int, error) { + file, err := os.Open(path) + if err != nil { + return nil, err + } + scanner := bufio.NewScanner(file) + scanner.Split(bufio.ScanLines) + + m := make(map[string]int) + for scanner.Scan() { + line := scanner.Text() + params := strings.Split(line, ";") + if len(params) <= 1 { + return nil, errors.New("Line does not contain enough Parameters") + } + port, err := strconv.Atoi(params[1]) + if err != nil { + return nil, err + } + m[params[0]] = port + } + + return m, nil +} diff --git a/rate-limit.go b/rate-limit.go index 479128a..337d7c0 100644 --- a/rate-limit.go +++ b/rate-limit.go @@ -4,13 +4,14 @@ import ( "net/http" "strings" "sync" + "sync/atomic" "time" "github.com/rs/zerolog/log" ) type Limiter struct { - current map[string]int + current map[string]*atomic.Int64 max int ticker *time.Ticker refill int @@ -20,7 +21,7 @@ type Limiter struct { func NewLimiter(maxRequests int, refills int, refillInterval time.Duration) Limiter { return Limiter{ - current: make(map[string]int), + current: make(map[string]*atomic.Int64), max: maxRequests, ticker: time.NewTicker(refillInterval), refill: refills, @@ -39,30 +40,25 @@ func (l *Limiter) Manage() { select { case ip := <-l.c: l.m.Lock() - if _, ok := l.current[ip]; ok { - l.current[ip] += 1 + if counter, ok := l.current[ip]; ok { + counter.Add(1) } else { - l.current[ip] = 1 + counter := &atomic.Int64{} + l.current[ip] = counter } l.m.Unlock() case <-l.ticker.C: - l.m.Lock() - start := time.Now() - count := len(l.current) - deleted := 0 - for ip, times := range l.current { - if times-l.refill <= 0 { - deleted += 1 - delete(l.current, ip) - } else { - l.current[ip] -= l.refill + l.m.RLock() + for ip := range l.current { + n := l.current[ip].Add(int64(-l.refill)) + if n < 0 { + l.current[ip].Store(0) + n = 0 } + log.Debug().Int64("bucket", n).Str("remote", ip).Msg("Updated limit") } - l.m.Unlock() - elapsed := time.Since(start) - if count >= 1 { - log.Info().Int("ips", count).Int("forgotten", deleted).Str("duration", elapsed.String()).Msg("Refill rate limit") - } + l.m.RUnlock() + log.Debug().Msg("Refreshed Limits") } } } @@ -73,7 +69,7 @@ func (l *Limiter) RateLimiter(next http.Handler) http.Handler { l.m.RLock() count, ok := l.current[addr] l.m.RUnlock() - if ok && count >= l.max { + if ok && int(count.Load()) >= l.max { hj, ok := w.(http.Hijacker) if !ok { r.Body.Close() @@ -82,7 +78,7 @@ func (l *Limiter) RateLimiter(next http.Handler) http.Handler { } conn, _, err := hj.Hijack() if err != nil { - panic(err) + log.Error().Err(err).Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Could not hijack connection") } log.Warn().Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Rate limited") diff --git a/router.go b/router.go index bdb13ed..e62b854 100644 --- a/router.go +++ b/router.go @@ -7,33 +7,18 @@ import ( "net/http" "net/http/httputil" + "github.com/pablu23/domain-router/util" "github.com/rs/zerolog/log" ) -// ConstMap for disallowing change of elements during runtime, for threadsafty -type constMap[K comparable, V any] struct { - dirty map[K]V -} - -func NewConstMap[K comparable, V any](m map[K]V) *constMap[K, V] { - return &constMap[K, V]{ - dirty: m, - } -} - -func (m *constMap[K, V]) Get(key K) (value V, ok bool) { - value, ok = m.dirty[key] - return value, ok -} - type Router struct { - domains *constMap[string, int] + domains *util.ImmutableMap[string, int] client *http.Client } func New(domains map[string]int, client *http.Client) Router { return Router{ - domains: NewConstMap(domains), + domains: util.NewImmutableMap(domains), client: client, } } @@ -62,6 +47,7 @@ func (router *Router) Route(w http.ResponseWriter, r *http.Request) { req.Header.Set(name, value) } } + req.Header.Set("X-Forwarded-For", r.RemoteAddr) for _, cookie := range r.Cookies() { req.AddCookie(cookie) @@ -80,7 +66,6 @@ func (router *Router) Route(w http.ResponseWriter, r *http.Request) { cookies := res.Cookies() for _, cookie := range cookies { - // fmt.Printf("Setting cookie, Name: %s, Value: %s\n", cookie.Name, cookie.Value) http.SetCookie(w, cookie) } @@ -116,14 +101,14 @@ func (router *Router) Route(w http.ResponseWriter, r *http.Request) { } func dumpRequest(w http.ResponseWriter, r *http.Request) bool { - if e := log.Debug(); e.Enabled() && r.Method == "POST" { + if e := log.Trace(); e.Enabled() && r.Method == "POST" { rDump, err := httputil.DumpRequest(r, true) if err != nil { log.Error().Err(err).Msg("Could not dump request") w.WriteHeader(http.StatusInternalServerError) return false } - log.Debug().Str("dump", string(rDump)).Send() + log.Trace().Str("dump", string(rDump)).Msg("Dumping Request") } return true } @@ -136,7 +121,7 @@ func dumpResponse(w http.ResponseWriter, r *http.Response) bool { w.WriteHeader(http.StatusInternalServerError) return false } - log.Trace().Str("dump", string(dump)).Send() + log.Trace().Str("dump", string(dump)).Msg("Dumping Response") } return true } diff --git a/util/constmap.go b/util/constmap.go new file mode 100644 index 0000000..0d18235 --- /dev/null +++ b/util/constmap.go @@ -0,0 +1,17 @@ +package util + +// ImmutableMap for disallowing change of elements during runtime, for threadsafty +type ImmutableMap[K comparable, V any] struct { + dirty map[K]V +} + +func NewImmutableMap[K comparable, V any](m map[K]V) *ImmutableMap[K, V] { + return &ImmutableMap[K, V]{ + dirty: m, + } +} + +func (m *ImmutableMap[K, V]) Get(key K) (value V, ok bool) { + value, ok = m.dirty[key] + return value, ok +}