diff --git a/acme/acme.go b/acme/acme.go new file mode 100644 index 0000000..f047980 --- /dev/null +++ b/acme/acme.go @@ -0,0 +1,111 @@ +package acme + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "net/http" + "os" + + "github.com/go-acme/lego/v4/certcrypto" + "github.com/go-acme/lego/v4/certificate" + "github.com/go-acme/lego/v4/challenge/http01" + "github.com/go-acme/lego/v4/challenge/tlsalpn01" + "github.com/go-acme/lego/v4/lego" + "github.com/go-acme/lego/v4/registration" + domainrouter "github.com/pablu23/domain-router" +) + +func SetupAcme(config *domainrouter.Config) error { + acme := config.Server.Ssl.Acme + + var privateKey *ecdsa.PrivateKey + if _, err := os.Stat(acme.KeyFile); errors.Is(err, os.ErrNotExist) { + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return err + } + err = os.WriteFile(acme.KeyFile, []byte(encode(privateKey)), 0666) + if err != nil { + return err + } + } else { + keyBytes, err := os.ReadFile(acme.KeyFile) + if err != nil { + return err + } + privateKey = decode(string(keyBytes)) + } + + user := User{ + Email: acme.Email, + key: privateKey, + } + + leConfig := lego.NewConfig(&user) + leConfig.CADirURL = acme.CADirURL + leConfig.Certificate.KeyType = certcrypto.RSA2048 + leConfig.HTTPClient.Transport = &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + + client, err := lego.NewClient(leConfig) + if err != nil { + return err + } + + // strconv.Itoa(config.Server.Port) + err = client.Challenge.SetHTTP01Provider(http01.NewProviderServer("", "5002")) + if err != nil { + return err + } + + err = client.Challenge.SetTLSALPN01Provider(tlsalpn01.NewProviderServer("", "5001")) + if err != nil { + return err + } + + reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}) + if err != nil { + return err + } + user.Registration = reg + + domains := make([]string, 0) + for _, host := range config.Hosts { + domains = append(domains, host.Domains...) + } + + request := certificate.ObtainRequest{ + Domains: domains, + Bundle: true, + } + + certificates, err := client.Certificate.Obtain(request) + if err != nil { + return err + } + + fmt.Printf("%#v\n", certificates) + return nil +} + +func encode(privateKey *ecdsa.PrivateKey) string { + x509Encoded, _ := x509.MarshalECPrivateKey(privateKey) + pemEncoded := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: x509Encoded}) + + return string(pemEncoded) +} + +func decode(pemEncoded string) *ecdsa.PrivateKey { + block, _ := pem.Decode([]byte(pemEncoded)) + x509Encoded := block.Bytes + privateKey, _ := x509.ParseECPrivateKey(x509Encoded) + + return privateKey +} diff --git a/acme/user.go b/acme/user.go new file mode 100644 index 0000000..d4ee407 --- /dev/null +++ b/acme/user.go @@ -0,0 +1,25 @@ +package acme + +import ( + "crypto" + + "github.com/go-acme/lego/v4/registration" +) + +type User struct { + Email string + Registration *registration.Resource + key crypto.PrivateKey +} + +func (u *User) GetEmail() string { + return u.Email +} + +func (u *User) GetRegistration() *registration.Resource { + return u.Registration +} + +func (u *User) GetPrivateKey() crypto.PrivateKey { + return u.key +} diff --git a/cmd/domain-router/main.go b/cmd/domain-router/main.go index 540cc14..5398dbc 100644 --- a/cmd/domain-router/main.go +++ b/cmd/domain-router/main.go @@ -1,148 +1,156 @@ -package main - -import ( - "crypto/tls" - "flag" - "fmt" - "io" - "net/http" - "net/url" - "os" - "time" - - domainrouter "github.com/pablu23/domain-router" - "github.com/pablu23/domain-router/middleware" - "github.com/rs/zerolog" - "github.com/rs/zerolog/log" - "gopkg.in/natefinch/lumberjack.v2" - "gopkg.in/yaml.v3" -) - -var ( - configFileFlag = flag.String("config", "config.yaml", "Path to config file") -) - -func main() { - flag.Parse() - - config, err := loadConfig(*configFileFlag) - if err != nil { - log.Fatal().Err(err).Str("path", *configFileFlag).Msg("Could not load Config") - } - - setupLogging(config) - client := &http.Client{ - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - } - - router := domainrouter.New(config, client) - mux := http.NewServeMux() - mux.HandleFunc("/", router.Route) - - if config.General.AnnouncePublic { - h, err := url.JoinPath("/", config.General.HealthEndpoint) - if err != nil { - log.Error().Err(err).Str("endpoint", config.General.HealthEndpoint).Msg("Could not create endpoint path") - h = "/healthz" - } - mux.HandleFunc(h, router.Healthz) - } - - pipeline := configureMiddleware(config) - - server := http.Server{ - Addr: fmt.Sprintf(":%d", config.Server.Port), - Handler: pipeline(mux), - } - - if config.Server.Ssl.Enabled { - server.TLSConfig = &tls.Config{ - GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { - cert, err := tls.LoadX509KeyPair(config.Server.Ssl.CertFile, config.Server.Ssl.KeyFile) - if err != nil { - return nil, err - } - return &cert, err - }, - } - - log.Info().Int("port", config.Server.Port).Str("cert", config.Server.Ssl.CertFile).Str("key", config.Server.Ssl.KeyFile).Msg("Starting server") - err := server.ListenAndServeTLS("", "") - log.Fatal().Err(err).Str("cert", config.Server.Ssl.CertFile).Str("key", config.Server.Ssl.KeyFile).Int("port", config.Server.Port).Msg("Could not start server") - } else { - log.Info().Int("port", config.Server.Port).Msg("Starting server") - err := server.ListenAndServe() - log.Fatal().Err(err).Int("port", config.Server.Port).Msg("Could not start server") - } -} - -func configureMiddleware(config *domainrouter.Config) middleware.Middleware { - middlewares := make([]middleware.Middleware, 0) - - if config.RateLimit.Enabled { - refillTicker, err := time.ParseDuration(config.RateLimit.RefillTicker) - if err != nil { - log.Fatal().Err(err).Str("refill", config.RateLimit.RefillTicker).Msg("Could not parse refill Ticker") - } - - cleanupTicker, err := time.ParseDuration(config.RateLimit.CleanupTicker) - if err != nil { - log.Fatal().Err(err).Str("cleanup", config.RateLimit.CleanupTicker).Msg("Could not parse cleanup Ticker") - } - limiter := middleware.NewLimiter(config.RateLimit.BucketSize, config.RateLimit.BucketRefill, refillTicker, cleanupTicker) - limiter.Start() - middlewares = append(middlewares, limiter.RateLimiter) - } - - if config.Logging.Requests { - middlewares = append(middlewares, middleware.RequestLogger) - } - - pipeline := middleware.Pipeline(middlewares...) - return pipeline -} - -func setupLogging(config *domainrouter.Config) { - logLevel, err := zerolog.ParseLevel(config.Logging.Level) - if err != nil { - log.Fatal().Err(err).Str("level", config.Logging.Level).Msg("Could not parse string to level") - } - - zerolog.SetGlobalLevel(logLevel) - log.Info().Str("level", config.Logging.Level).Msg("Set logging level") - if config.Logging.Pretty { - log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) - } - - if config.Logging.File.Enabled { - var console io.Writer = os.Stderr - if config.Logging.Pretty { - console = zerolog.ConsoleWriter{Out: os.Stderr} - } - - log.Logger = log.Output(zerolog.MultiLevelWriter(console, &lumberjack.Logger{ - Filename: config.Logging.File.Path, - MaxAge: config.Logging.File.MaxAge, - MaxBackups: config.Logging.File.MaxBackups, - })) - } -} - -func loadConfig(path string) (*domainrouter.Config, error) { - f, err := os.Open(path) - if err != nil { - return nil, err - } - defer f.Close() - - var cfg domainrouter.Config - decoder := yaml.NewDecoder(f) - err = decoder.Decode(&cfg) - if err != nil { - return nil, err - } - - return &cfg, err -} +package main + +import ( + "crypto/tls" + "flag" + "fmt" + "io" + "net/http" + "net/url" + "os" + "time" + + domainrouter "github.com/pablu23/domain-router" + "github.com/pablu23/domain-router/acme" + "github.com/pablu23/domain-router/middleware" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "gopkg.in/natefinch/lumberjack.v2" + "gopkg.in/yaml.v3" +) + +var ( + configFileFlag = flag.String("config", "config.yaml", "Path to config file") +) + +func main() { + flag.Parse() + + config, err := loadConfig(*configFileFlag) + if err != nil { + log.Fatal().Err(err).Str("path", *configFileFlag).Msg("Could not load Config") + } + + setupLogging(config) + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + router := domainrouter.New(config, client) + mux := http.NewServeMux() + mux.HandleFunc("/", router.Route) + + if config.General.AnnouncePublic { + h, err := url.JoinPath("/", config.General.HealthEndpoint) + if err != nil { + log.Error().Err(err).Str("endpoint", config.General.HealthEndpoint).Msg("Could not create endpoint path") + h = "/healthz" + } + mux.HandleFunc(h, router.Healthz) + } + + pipeline := configureMiddleware(config) + + server := http.Server{ + Addr: fmt.Sprintf(":%d", config.Server.Port), + Handler: pipeline(mux), + } + + if config.Server.Ssl.Enabled { + server.TLSConfig = &tls.Config{ + GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { + cert, err := tls.LoadX509KeyPair(config.Server.Ssl.CertFile, config.Server.Ssl.KeyFile) + if err != nil { + return nil, err + } + return &cert, err + }, + } + + if config.Server.Ssl.Acme.Enabled { + err := acme.SetupAcme(config) + if err != nil { + log.Fatal().Err(err).Msg("unable to setup acme") + } + } + + log.Info().Int("port", config.Server.Port).Str("cert", config.Server.Ssl.CertFile).Str("key", config.Server.Ssl.KeyFile).Msg("Starting server") + err := server.ListenAndServeTLS("", "") + log.Fatal().Err(err).Str("cert", config.Server.Ssl.CertFile).Str("key", config.Server.Ssl.KeyFile).Int("port", config.Server.Port).Msg("Could not start server") + } else { + log.Info().Int("port", config.Server.Port).Msg("Starting server") + err := server.ListenAndServe() + log.Fatal().Err(err).Int("port", config.Server.Port).Msg("Could not start server") + } +} + +func configureMiddleware(config *domainrouter.Config) middleware.Middleware { + middlewares := make([]middleware.Middleware, 0) + + if config.RateLimit.Enabled { + refillTicker, err := time.ParseDuration(config.RateLimit.RefillTicker) + if err != nil { + log.Fatal().Err(err).Str("refill", config.RateLimit.RefillTicker).Msg("Could not parse refill Ticker") + } + + cleanupTicker, err := time.ParseDuration(config.RateLimit.CleanupTicker) + if err != nil { + log.Fatal().Err(err).Str("cleanup", config.RateLimit.CleanupTicker).Msg("Could not parse cleanup Ticker") + } + limiter := middleware.NewLimiter(config.RateLimit.BucketSize, config.RateLimit.BucketRefill, refillTicker, cleanupTicker) + limiter.Start() + middlewares = append(middlewares, limiter.RateLimiter) + } + + if config.Logging.Requests { + middlewares = append(middlewares, middleware.RequestLogger) + } + + pipeline := middleware.Pipeline(middlewares...) + return pipeline +} + +func setupLogging(config *domainrouter.Config) { + logLevel, err := zerolog.ParseLevel(config.Logging.Level) + if err != nil { + log.Fatal().Err(err).Str("level", config.Logging.Level).Msg("Could not parse string to level") + } + + zerolog.SetGlobalLevel(logLevel) + log.Info().Str("level", config.Logging.Level).Msg("Set logging level") + if config.Logging.Pretty { + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + } + + if config.Logging.File.Enabled { + var console io.Writer = os.Stderr + if config.Logging.Pretty { + console = zerolog.ConsoleWriter{Out: os.Stderr} + } + + log.Logger = log.Output(zerolog.MultiLevelWriter(console, &lumberjack.Logger{ + Filename: config.Logging.File.Path, + MaxAge: config.Logging.File.MaxAge, + MaxBackups: config.Logging.File.MaxBackups, + })) + } +} + +func loadConfig(path string) (*domainrouter.Config, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + var cfg domainrouter.Config + decoder := yaml.NewDecoder(f) + err = decoder.Decode(&cfg) + if err != nil { + return nil, err + } + + return &cfg, err +} diff --git a/config.go b/config.go index 4a3fe25..c2f70cc 100644 --- a/config.go +++ b/config.go @@ -1,41 +1,47 @@ -package domainrouter - -type Config struct { - General struct { - AnnouncePublic bool `yaml:"announce"` - HealthEndpoint string `yaml:"healthz"` - } `yaml:"general"` - Server struct { - Port int `yaml:"port"` - Ssl struct { - Enabled bool `yaml:"enabled"` - CertFile string `yaml:"certFile"` - KeyFile string `yaml:"keyFile"` - } `yaml:"ssl"` - } `yaml:"server"` - Hosts []struct { - Port int `yaml:"port"` - Remotes []string `yaml:"remotes"` - Domains []string `yaml:"domains"` - Public bool `yaml:"public"` - Secure bool `yaml:"secure"` - } `yaml:"hosts"` - RateLimit struct { - Enabled bool `yaml:"enabled"` - BucketSize int `yaml:"bucketSize"` - RefillTicker string `yaml:"refillTime"` - CleanupTicker string `yaml:"cleanupTime"` - BucketRefill int `yaml:"refillSize"` - } `yaml:"rateLimit"` - Logging struct { - Level string `yaml:"level"` - Pretty bool `yaml:"pretty"` - Requests bool `yaml:"requests"` - File struct { - Enabled bool `yaml:"enabled"` - Path string `yaml:"path"` - MaxAge int `yaml:"maxAge"` - MaxBackups int `yamls:"maxBackups"` - } `yaml:"file"` - } `yaml:"logging"` -} +package domainrouter + +type Config struct { + General struct { + AnnouncePublic bool `yaml:"announce"` + HealthEndpoint string `yaml:"healthz"` + } `yaml:"general"` + Server struct { + Port int `yaml:"port"` + Ssl struct { + Enabled bool `yaml:"enabled"` + CertFile string `yaml:"certFile"` + KeyFile string `yaml:"keyFile"` + Acme struct { + Enabled bool `yaml:"enabled"` + Email string `yaml:"email"` + KeyFile string `yaml:"keyFile"` + CADirURL string `yaml:"caDirUrl"` + } `yaml:"acme"` + } `yaml:"ssl"` + } `yaml:"server"` + Hosts []struct { + Port int `yaml:"port"` + Remotes []string `yaml:"remotes"` + Domains []string `yaml:"domains"` + Public bool `yaml:"public"` + Secure bool `yaml:"secure"` + } `yaml:"hosts"` + RateLimit struct { + Enabled bool `yaml:"enabled"` + BucketSize int `yaml:"bucketSize"` + RefillTicker string `yaml:"refillTime"` + CleanupTicker string `yaml:"cleanupTime"` + BucketRefill int `yaml:"refillSize"` + } `yaml:"rateLimit"` + Logging struct { + Level string `yaml:"level"` + Pretty bool `yaml:"pretty"` + Requests bool `yaml:"requests"` + File struct { + Enabled bool `yaml:"enabled"` + Path string `yaml:"path"` + MaxAge int `yaml:"maxAge"` + MaxBackups int `yamls:"maxBackups"` + } `yaml:"file"` + } `yaml:"logging"` +} diff --git a/config.yaml b/config.yaml index c48b604..b51395e 100644 --- a/config.yaml +++ b/config.yaml @@ -1,70 +1,75 @@ -server: - port: 443 - ssl: - enabled: true - certFile: server.crt - keyFile: server.key - - -logging: - level: debug - # Pretty print for human consumption otherwise json - pretty: true - # Log incoming requests - requests: true - # Log to file aswell as stderr - file: - enabled: false - maxAge: 14 - maxBackups: 10 - path: ~/logs/router - - -rateLimit: - enabled: true - # How many requests per ip adress are allowed - bucketSize: 50 - # How many requests per ip address are refilled - refillSize: 50 - # How often requests per ip address are refilled - refillTime: 30s - # How often Ip Addresses get cleaned up (only ip addresses with max allowed requests are cleaned up) - cleanupTime: 45s - - -hosts: - # Remote address to request - - remotes: - - localhost - - 192.168.2.154 - # Port on which to request - port: 8181 - # Health check if announce is true - public: true - # Domains which get redirected to host - domains: - - localhost - - test.localhost - - - remotes: - - localhost - port: 8282 - public: false - domains: - - private.localhost - - - remotes: - - www.google.com - - localhost - port: 443 - public: false - # Uses https under the hood to communicate with the remote host - secure: true - domains: - - google.localhost - -general: - # Expose health endpoint, that requests health endpoints from hosts which are public - announce: true - # Path to health endpoint on router, is allowed to conflict with hosts, but overwrites specific host endpoint - healthz: healthz +server: + port: 443 + ssl: + enabled: true + certFile: server.crt + keyFile: server.key + acme: + enabled: true + email: me@pablu.de + keyFile: userKey.key + caDirUrl: https://192.168.2.154:14000/dir + + +logging: + level: debug + # Pretty print for human consumption otherwise json + pretty: true + # Log incoming requests + requests: true + # Log to file aswell as stderr + file: + enabled: false + maxAge: 14 + maxBackups: 10 + path: ~/logs/router + + +rateLimit: + enabled: true + # How many requests per ip adress are allowed + bucketSize: 50 + # How many requests per ip address are refilled + refillSize: 50 + # How often requests per ip address are refilled + refillTime: 30s + # How often Ip Addresses get cleaned up (only ip addresses with max allowed requests are cleaned up) + cleanupTime: 45s + + +hosts: + # Remote address to request + - remotes: + - localhost + - 192.168.2.154 + # Port on which to request + port: 8181 + # Health check if announce is true + public: true + # Domains which get redirected to host + domains: + - localhost + - test.localhost + + - remotes: + - localhost + port: 8282 + public: false + domains: + - private.localhost + + # - remotes: + # - www.google.com + # - localhost + # port: 443 + # public: false + # # Uses https under the hood to communicate with the remote host + # secure: true + # domains: + # - google.localhost + +general: + # Expose health endpoint, that requests health endpoints from hosts which are public + announce: true + # Path to health endpoint on router, is allowed to conflict with hosts, but overwrites specific host endpoint + healthz: healthz diff --git a/go.mod b/go.mod index 51a8f93..ebaf4ff 100644 --- a/go.mod +++ b/go.mod @@ -1,14 +1,32 @@ module github.com/pablu23/domain-router -go 1.23 +go 1.23.0 -require github.com/rs/zerolog v1.33.0 +toolchain go1.24.4 + +require github.com/rs/zerolog v1.34.0 require ( + github.com/cenkalti/backoff v2.2.1+incompatible // indirect + github.com/cenkalti/backoff/v4 v4.3.0 // indirect + github.com/go-jose/go-jose/v4 v4.0.5 // indirect + github.com/miekg/dns v1.1.67 // indirect + golang.org/x/crypto v0.40.0 // indirect + golang.org/x/mod v0.25.0 // indirect + golang.org/x/net v0.41.0 // indirect + golang.org/x/sync v0.16.0 // indirect + golang.org/x/text v0.27.0 // indirect + golang.org/x/tools v0.34.0 // indirect + gopkg.in/square/go-jose.v2 v2.6.0 // indirect +) + +require ( + github.com/go-acme/lego v2.7.2+incompatible + github.com/go-acme/lego/v4 v4.24.0 github.com/mattn/go-colorable v0.1.13 // indirect - github.com/mattn/go-isatty v0.0.19 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect github.com/urfave/negroni v1.0.0 - golang.org/x/sys v0.12.0 // indirect + golang.org/x/sys v0.34.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/yaml.v3 v3.0.1 ) diff --git a/go.sum b/go.sum index 1d68d5e..862742e 100644 --- a/go.sum +++ b/go.sum @@ -1,22 +1,55 @@ +github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4= +github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM= +github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= +github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/go-acme/lego v2.7.2+incompatible h1:ThhpPBgf6oa9X/vRd0kEmWOsX7+vmYdckmGZSb+FEp0= +github.com/go-acme/lego v2.7.2+incompatible/go.mod h1:yzMNe9CasVUhkquNvti5nAtPmG94USbYxYrZfTkIn0M= +github.com/go-acme/lego/v4 v4.24.0 h1:pe0q49JKxfSGEP3lkgkMVQrZM1KbD+e0dpJ2McYsiVw= +github.com/go-acme/lego/v4 v4.24.0/go.mod h1:hkstZY6D0jylIrZbuNmEQrWQxTIfaJH7prwaWvKDOjw= +github.com/go-jose/go-jose/v4 v4.0.5 h1:M6T8+mKZl/+fNNuFHvGIzDz7BTLQPIounk/b9dw3AaE= +github.com/go-jose/go-jose/v4 v4.0.5/go.mod h1:s3P1lRrkT8igV8D9OjyL4WRyHvjB6a4JSllnOrmmBOA= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/miekg/dns v1.1.67 h1:kg0EHj0G4bfT5/oOys6HhZw4vmMlnoZ+gDu8tJ/AlI0= +github.com/miekg/dns v1.1.67/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/urfave/negroni v1.0.0 h1:kIimOitoypq34K7TG7DUaJ9kq/N4Ofuwi1sjz0KipXc= github.com/urfave/negroni v1.0.0/go.mod h1:Meg73S6kFm/4PpbYdq35yYWoCZ9mS/YSx+lKnmiohz4= +golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= +golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY= +golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w= +golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= +golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= +golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= +golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= +golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= +golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= +golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= +gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI= +gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/middleware/rate-limit.go b/middleware/rate-limit.go index 6589276..ac0e471 100644 --- a/middleware/rate-limit.go +++ b/middleware/rate-limit.go @@ -1,123 +1,122 @@ -package middleware - -import ( - "net/http" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/rs/zerolog/log" -) - -type Limiter struct { - currentBuckets map[string]*atomic.Int64 - bucketSize int - refillTicker *time.Ticker - cleanupTicker *time.Ticker - bucketRefill int - rwLock *sync.RWMutex - rateChannel chan string -} - -func NewLimiter(maxRequests int, refills int, refillInterval time.Duration, cleanupInterval time.Duration) Limiter { - return Limiter{ - currentBuckets: make(map[string]*atomic.Int64), - bucketSize: maxRequests, - refillTicker: time.NewTicker(refillInterval), - cleanupTicker: time.NewTicker(cleanupInterval), - bucketRefill: refills, - rwLock: &sync.RWMutex{}, - rateChannel: make(chan string), - } -} - -func (l *Limiter) Start() { - go l.Manage() - return -} - -func (l *Limiter) UpdateCleanupTime(new time.Duration) { - l.cleanupTicker.Reset(new) -} - -func (l *Limiter) Manage() { - for { - select { - case ip := <-l.rateChannel: - if l.AddIfExists(ip) { - break - } - - l.rwLock.Lock() - counter := &atomic.Int64{} - l.currentBuckets[ip] = counter - l.rwLock.Unlock() - case <-l.refillTicker.C: - l.rwLock.RLock() - for ip := range l.currentBuckets { - n := l.currentBuckets[ip].Add(int64(-l.bucketRefill)) - if n < 0 { - l.currentBuckets[ip].Store(0) - n = 0 - } - log.Trace().Int64("bucket", n).Str("remote", ip).Msg("Updated limit") - } - l.rwLock.RUnlock() - log.Trace().Msg("Refreshed Limits") - case <-l.cleanupTicker.C: - start := time.Now() - l.rwLock.Lock() - deletedBuckets := 0 - for ip := range l.currentBuckets { - if l.currentBuckets[ip].Load() <= 0 { - delete(l.currentBuckets, ip) - deletedBuckets += 1 - } - } - l.rwLock.Unlock() - duration := time.Since(start) - log.Debug().Str("duration", duration.String()).Int("deleted_buckets", deletedBuckets).Msg("Cleaned up Buckets") - } - } -} - -// Adds one if ip already exists and returns true -// If ip doesnt yet exist only returns false -func (l *Limiter) AddIfExists(ip string) bool { - l.rwLock.RLock() - defer l.rwLock.RUnlock() - if counter, ok := l.currentBuckets[ip]; ok { - counter.Add(1) - return true - } - return false -} - -func (l *Limiter) RateLimiter(next http.Handler) http.Handler { - log.Info().Int("bucket_size", l.bucketSize).Int("bucket_refill", l.bucketRefill).Msg("Enabling Ratelimits") - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - addr := strings.Split(r.RemoteAddr, ":")[0] - l.rwLock.RLock() - count, ok := l.currentBuckets[addr] - l.rwLock.RUnlock() - if ok && int(count.Load()) >= l.bucketSize { - hj, ok := w.(http.Hijacker) - if !ok { - r.Body.Close() - log.Warn().Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Rate limited") - return - } - conn, _, err := hj.Hijack() - if err != nil { - log.Error().Err(err).Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Could not hijack connection") - } - - log.Warn().Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Rate limited") - conn.Close() - return - } - l.rateChannel <- addr - next.ServeHTTP(w, r) - }) -} +package middleware + +import ( + "net/http" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/rs/zerolog/log" +) + +type Limiter struct { + currentBuckets map[string]*atomic.Int64 + bucketSize int + refillTicker *time.Ticker + cleanupTicker *time.Ticker + bucketRefill int + rwLock *sync.RWMutex + rateChannel chan string +} + +func NewLimiter(maxRequests int, refills int, refillInterval time.Duration, cleanupInterval time.Duration) Limiter { + return Limiter{ + currentBuckets: make(map[string]*atomic.Int64), + bucketSize: maxRequests, + refillTicker: time.NewTicker(refillInterval), + cleanupTicker: time.NewTicker(cleanupInterval), + bucketRefill: refills, + rwLock: &sync.RWMutex{}, + rateChannel: make(chan string), + } +} + +func (l *Limiter) Start() { + go l.Manage() +} + +func (l *Limiter) UpdateCleanupTime(new time.Duration) { + l.cleanupTicker.Reset(new) +} + +func (l *Limiter) Manage() { + for { + select { + case ip := <-l.rateChannel: + if l.AddIfExists(ip) { + break + } + + l.rwLock.Lock() + counter := &atomic.Int64{} + l.currentBuckets[ip] = counter + l.rwLock.Unlock() + case <-l.refillTicker.C: + l.rwLock.RLock() + for ip := range l.currentBuckets { + n := l.currentBuckets[ip].Add(int64(-l.bucketRefill)) + if n < 0 { + l.currentBuckets[ip].Store(0) + n = 0 + } + log.Trace().Int64("bucket", n).Str("remote", ip).Msg("Updated limit") + } + l.rwLock.RUnlock() + log.Trace().Msg("Refreshed Limits") + case <-l.cleanupTicker.C: + start := time.Now() + l.rwLock.Lock() + deletedBuckets := 0 + for ip := range l.currentBuckets { + if l.currentBuckets[ip].Load() <= 0 { + delete(l.currentBuckets, ip) + deletedBuckets += 1 + } + } + l.rwLock.Unlock() + duration := time.Since(start) + log.Debug().Str("duration", duration.String()).Int("deleted_buckets", deletedBuckets).Msg("Cleaned up Buckets") + } + } +} + +// Adds one if ip already exists and returns true +// If ip doesnt yet exist only returns false +func (l *Limiter) AddIfExists(ip string) bool { + l.rwLock.RLock() + defer l.rwLock.RUnlock() + if counter, ok := l.currentBuckets[ip]; ok { + counter.Add(1) + return true + } + return false +} + +func (l *Limiter) RateLimiter(next http.Handler) http.Handler { + log.Info().Int("bucket_size", l.bucketSize).Int("bucket_refill", l.bucketRefill).Msg("Enabling Ratelimits") + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + addr := strings.Split(r.RemoteAddr, ":")[0] + l.rwLock.RLock() + count, ok := l.currentBuckets[addr] + l.rwLock.RUnlock() + if ok && int(count.Load()) >= l.bucketSize { + hj, ok := w.(http.Hijacker) + if !ok { + r.Body.Close() + log.Warn().Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Rate limited") + return + } + conn, _, err := hj.Hijack() + if err != nil { + log.Error().Err(err).Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Could not hijack connection") + } + + log.Warn().Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Rate limited") + conn.Close() + return + } + l.rateChannel <- addr + next.ServeHTTP(w, r) + }) +}