Added first acme implementation

This commit is contained in:
Pablu23
2025-07-20 18:34:20 +02:00
parent dc2fe84a96
commit 77a880cee1
8 changed files with 591 additions and 386 deletions

111
acme/acme.go Normal file
View File

@@ -0,0 +1,111 @@
package acme
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"net/http"
"os"
"github.com/go-acme/lego/v4/certcrypto"
"github.com/go-acme/lego/v4/certificate"
"github.com/go-acme/lego/v4/challenge/http01"
"github.com/go-acme/lego/v4/challenge/tlsalpn01"
"github.com/go-acme/lego/v4/lego"
"github.com/go-acme/lego/v4/registration"
domainrouter "github.com/pablu23/domain-router"
)
func SetupAcme(config *domainrouter.Config) error {
acme := config.Server.Ssl.Acme
var privateKey *ecdsa.PrivateKey
if _, err := os.Stat(acme.KeyFile); errors.Is(err, os.ErrNotExist) {
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return err
}
err = os.WriteFile(acme.KeyFile, []byte(encode(privateKey)), 0666)
if err != nil {
return err
}
} else {
keyBytes, err := os.ReadFile(acme.KeyFile)
if err != nil {
return err
}
privateKey = decode(string(keyBytes))
}
user := User{
Email: acme.Email,
key: privateKey,
}
leConfig := lego.NewConfig(&user)
leConfig.CADirURL = acme.CADirURL
leConfig.Certificate.KeyType = certcrypto.RSA2048
leConfig.HTTPClient.Transport = &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
client, err := lego.NewClient(leConfig)
if err != nil {
return err
}
// strconv.Itoa(config.Server.Port)
err = client.Challenge.SetHTTP01Provider(http01.NewProviderServer("", "5002"))
if err != nil {
return err
}
err = client.Challenge.SetTLSALPN01Provider(tlsalpn01.NewProviderServer("", "5001"))
if err != nil {
return err
}
reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
if err != nil {
return err
}
user.Registration = reg
domains := make([]string, 0)
for _, host := range config.Hosts {
domains = append(domains, host.Domains...)
}
request := certificate.ObtainRequest{
Domains: domains,
Bundle: true,
}
certificates, err := client.Certificate.Obtain(request)
if err != nil {
return err
}
fmt.Printf("%#v\n", certificates)
return nil
}
func encode(privateKey *ecdsa.PrivateKey) string {
x509Encoded, _ := x509.MarshalECPrivateKey(privateKey)
pemEncoded := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: x509Encoded})
return string(pemEncoded)
}
func decode(pemEncoded string) *ecdsa.PrivateKey {
block, _ := pem.Decode([]byte(pemEncoded))
x509Encoded := block.Bytes
privateKey, _ := x509.ParseECPrivateKey(x509Encoded)
return privateKey
}

25
acme/user.go Normal file
View File

@@ -0,0 +1,25 @@
package acme
import (
"crypto"
"github.com/go-acme/lego/v4/registration"
)
type User struct {
Email string
Registration *registration.Resource
key crypto.PrivateKey
}
func (u *User) GetEmail() string {
return u.Email
}
func (u *User) GetRegistration() *registration.Resource {
return u.Registration
}
func (u *User) GetPrivateKey() crypto.PrivateKey {
return u.key
}

View File

@@ -1,148 +1,156 @@
package main
import (
"crypto/tls"
"flag"
"fmt"
"io"
"net/http"
"net/url"
"os"
"time"
domainrouter "github.com/pablu23/domain-router"
"github.com/pablu23/domain-router/middleware"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"gopkg.in/natefinch/lumberjack.v2"
"gopkg.in/yaml.v3"
)
var (
configFileFlag = flag.String("config", "config.yaml", "Path to config file")
)
func main() {
flag.Parse()
config, err := loadConfig(*configFileFlag)
if err != nil {
log.Fatal().Err(err).Str("path", *configFileFlag).Msg("Could not load Config")
}
setupLogging(config)
client := &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
router := domainrouter.New(config, client)
mux := http.NewServeMux()
mux.HandleFunc("/", router.Route)
if config.General.AnnouncePublic {
h, err := url.JoinPath("/", config.General.HealthEndpoint)
if err != nil {
log.Error().Err(err).Str("endpoint", config.General.HealthEndpoint).Msg("Could not create endpoint path")
h = "/healthz"
}
mux.HandleFunc(h, router.Healthz)
}
pipeline := configureMiddleware(config)
server := http.Server{
Addr: fmt.Sprintf(":%d", config.Server.Port),
Handler: pipeline(mux),
}
if config.Server.Ssl.Enabled {
server.TLSConfig = &tls.Config{
GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := tls.LoadX509KeyPair(config.Server.Ssl.CertFile, config.Server.Ssl.KeyFile)
if err != nil {
return nil, err
}
return &cert, err
},
}
log.Info().Int("port", config.Server.Port).Str("cert", config.Server.Ssl.CertFile).Str("key", config.Server.Ssl.KeyFile).Msg("Starting server")
err := server.ListenAndServeTLS("", "")
log.Fatal().Err(err).Str("cert", config.Server.Ssl.CertFile).Str("key", config.Server.Ssl.KeyFile).Int("port", config.Server.Port).Msg("Could not start server")
} else {
log.Info().Int("port", config.Server.Port).Msg("Starting server")
err := server.ListenAndServe()
log.Fatal().Err(err).Int("port", config.Server.Port).Msg("Could not start server")
}
}
func configureMiddleware(config *domainrouter.Config) middleware.Middleware {
middlewares := make([]middleware.Middleware, 0)
if config.RateLimit.Enabled {
refillTicker, err := time.ParseDuration(config.RateLimit.RefillTicker)
if err != nil {
log.Fatal().Err(err).Str("refill", config.RateLimit.RefillTicker).Msg("Could not parse refill Ticker")
}
cleanupTicker, err := time.ParseDuration(config.RateLimit.CleanupTicker)
if err != nil {
log.Fatal().Err(err).Str("cleanup", config.RateLimit.CleanupTicker).Msg("Could not parse cleanup Ticker")
}
limiter := middleware.NewLimiter(config.RateLimit.BucketSize, config.RateLimit.BucketRefill, refillTicker, cleanupTicker)
limiter.Start()
middlewares = append(middlewares, limiter.RateLimiter)
}
if config.Logging.Requests {
middlewares = append(middlewares, middleware.RequestLogger)
}
pipeline := middleware.Pipeline(middlewares...)
return pipeline
}
func setupLogging(config *domainrouter.Config) {
logLevel, err := zerolog.ParseLevel(config.Logging.Level)
if err != nil {
log.Fatal().Err(err).Str("level", config.Logging.Level).Msg("Could not parse string to level")
}
zerolog.SetGlobalLevel(logLevel)
log.Info().Str("level", config.Logging.Level).Msg("Set logging level")
if config.Logging.Pretty {
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
}
if config.Logging.File.Enabled {
var console io.Writer = os.Stderr
if config.Logging.Pretty {
console = zerolog.ConsoleWriter{Out: os.Stderr}
}
log.Logger = log.Output(zerolog.MultiLevelWriter(console, &lumberjack.Logger{
Filename: config.Logging.File.Path,
MaxAge: config.Logging.File.MaxAge,
MaxBackups: config.Logging.File.MaxBackups,
}))
}
}
func loadConfig(path string) (*domainrouter.Config, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()
var cfg domainrouter.Config
decoder := yaml.NewDecoder(f)
err = decoder.Decode(&cfg)
if err != nil {
return nil, err
}
return &cfg, err
}
package main
import (
"crypto/tls"
"flag"
"fmt"
"io"
"net/http"
"net/url"
"os"
"time"
domainrouter "github.com/pablu23/domain-router"
"github.com/pablu23/domain-router/acme"
"github.com/pablu23/domain-router/middleware"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"gopkg.in/natefinch/lumberjack.v2"
"gopkg.in/yaml.v3"
)
var (
configFileFlag = flag.String("config", "config.yaml", "Path to config file")
)
func main() {
flag.Parse()
config, err := loadConfig(*configFileFlag)
if err != nil {
log.Fatal().Err(err).Str("path", *configFileFlag).Msg("Could not load Config")
}
setupLogging(config)
client := &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
router := domainrouter.New(config, client)
mux := http.NewServeMux()
mux.HandleFunc("/", router.Route)
if config.General.AnnouncePublic {
h, err := url.JoinPath("/", config.General.HealthEndpoint)
if err != nil {
log.Error().Err(err).Str("endpoint", config.General.HealthEndpoint).Msg("Could not create endpoint path")
h = "/healthz"
}
mux.HandleFunc(h, router.Healthz)
}
pipeline := configureMiddleware(config)
server := http.Server{
Addr: fmt.Sprintf(":%d", config.Server.Port),
Handler: pipeline(mux),
}
if config.Server.Ssl.Enabled {
server.TLSConfig = &tls.Config{
GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := tls.LoadX509KeyPair(config.Server.Ssl.CertFile, config.Server.Ssl.KeyFile)
if err != nil {
return nil, err
}
return &cert, err
},
}
if config.Server.Ssl.Acme.Enabled {
err := acme.SetupAcme(config)
if err != nil {
log.Fatal().Err(err).Msg("unable to setup acme")
}
}
log.Info().Int("port", config.Server.Port).Str("cert", config.Server.Ssl.CertFile).Str("key", config.Server.Ssl.KeyFile).Msg("Starting server")
err := server.ListenAndServeTLS("", "")
log.Fatal().Err(err).Str("cert", config.Server.Ssl.CertFile).Str("key", config.Server.Ssl.KeyFile).Int("port", config.Server.Port).Msg("Could not start server")
} else {
log.Info().Int("port", config.Server.Port).Msg("Starting server")
err := server.ListenAndServe()
log.Fatal().Err(err).Int("port", config.Server.Port).Msg("Could not start server")
}
}
func configureMiddleware(config *domainrouter.Config) middleware.Middleware {
middlewares := make([]middleware.Middleware, 0)
if config.RateLimit.Enabled {
refillTicker, err := time.ParseDuration(config.RateLimit.RefillTicker)
if err != nil {
log.Fatal().Err(err).Str("refill", config.RateLimit.RefillTicker).Msg("Could not parse refill Ticker")
}
cleanupTicker, err := time.ParseDuration(config.RateLimit.CleanupTicker)
if err != nil {
log.Fatal().Err(err).Str("cleanup", config.RateLimit.CleanupTicker).Msg("Could not parse cleanup Ticker")
}
limiter := middleware.NewLimiter(config.RateLimit.BucketSize, config.RateLimit.BucketRefill, refillTicker, cleanupTicker)
limiter.Start()
middlewares = append(middlewares, limiter.RateLimiter)
}
if config.Logging.Requests {
middlewares = append(middlewares, middleware.RequestLogger)
}
pipeline := middleware.Pipeline(middlewares...)
return pipeline
}
func setupLogging(config *domainrouter.Config) {
logLevel, err := zerolog.ParseLevel(config.Logging.Level)
if err != nil {
log.Fatal().Err(err).Str("level", config.Logging.Level).Msg("Could not parse string to level")
}
zerolog.SetGlobalLevel(logLevel)
log.Info().Str("level", config.Logging.Level).Msg("Set logging level")
if config.Logging.Pretty {
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
}
if config.Logging.File.Enabled {
var console io.Writer = os.Stderr
if config.Logging.Pretty {
console = zerolog.ConsoleWriter{Out: os.Stderr}
}
log.Logger = log.Output(zerolog.MultiLevelWriter(console, &lumberjack.Logger{
Filename: config.Logging.File.Path,
MaxAge: config.Logging.File.MaxAge,
MaxBackups: config.Logging.File.MaxBackups,
}))
}
}
func loadConfig(path string) (*domainrouter.Config, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()
var cfg domainrouter.Config
decoder := yaml.NewDecoder(f)
err = decoder.Decode(&cfg)
if err != nil {
return nil, err
}
return &cfg, err
}

View File

@@ -1,41 +1,47 @@
package domainrouter
type Config struct {
General struct {
AnnouncePublic bool `yaml:"announce"`
HealthEndpoint string `yaml:"healthz"`
} `yaml:"general"`
Server struct {
Port int `yaml:"port"`
Ssl struct {
Enabled bool `yaml:"enabled"`
CertFile string `yaml:"certFile"`
KeyFile string `yaml:"keyFile"`
} `yaml:"ssl"`
} `yaml:"server"`
Hosts []struct {
Port int `yaml:"port"`
Remotes []string `yaml:"remotes"`
Domains []string `yaml:"domains"`
Public bool `yaml:"public"`
Secure bool `yaml:"secure"`
} `yaml:"hosts"`
RateLimit struct {
Enabled bool `yaml:"enabled"`
BucketSize int `yaml:"bucketSize"`
RefillTicker string `yaml:"refillTime"`
CleanupTicker string `yaml:"cleanupTime"`
BucketRefill int `yaml:"refillSize"`
} `yaml:"rateLimit"`
Logging struct {
Level string `yaml:"level"`
Pretty bool `yaml:"pretty"`
Requests bool `yaml:"requests"`
File struct {
Enabled bool `yaml:"enabled"`
Path string `yaml:"path"`
MaxAge int `yaml:"maxAge"`
MaxBackups int `yamls:"maxBackups"`
} `yaml:"file"`
} `yaml:"logging"`
}
package domainrouter
type Config struct {
General struct {
AnnouncePublic bool `yaml:"announce"`
HealthEndpoint string `yaml:"healthz"`
} `yaml:"general"`
Server struct {
Port int `yaml:"port"`
Ssl struct {
Enabled bool `yaml:"enabled"`
CertFile string `yaml:"certFile"`
KeyFile string `yaml:"keyFile"`
Acme struct {
Enabled bool `yaml:"enabled"`
Email string `yaml:"email"`
KeyFile string `yaml:"keyFile"`
CADirURL string `yaml:"caDirUrl"`
} `yaml:"acme"`
} `yaml:"ssl"`
} `yaml:"server"`
Hosts []struct {
Port int `yaml:"port"`
Remotes []string `yaml:"remotes"`
Domains []string `yaml:"domains"`
Public bool `yaml:"public"`
Secure bool `yaml:"secure"`
} `yaml:"hosts"`
RateLimit struct {
Enabled bool `yaml:"enabled"`
BucketSize int `yaml:"bucketSize"`
RefillTicker string `yaml:"refillTime"`
CleanupTicker string `yaml:"cleanupTime"`
BucketRefill int `yaml:"refillSize"`
} `yaml:"rateLimit"`
Logging struct {
Level string `yaml:"level"`
Pretty bool `yaml:"pretty"`
Requests bool `yaml:"requests"`
File struct {
Enabled bool `yaml:"enabled"`
Path string `yaml:"path"`
MaxAge int `yaml:"maxAge"`
MaxBackups int `yamls:"maxBackups"`
} `yaml:"file"`
} `yaml:"logging"`
}

View File

@@ -1,70 +1,75 @@
server:
port: 443
ssl:
enabled: true
certFile: server.crt
keyFile: server.key
logging:
level: debug
# Pretty print for human consumption otherwise json
pretty: true
# Log incoming requests
requests: true
# Log to file aswell as stderr
file:
enabled: false
maxAge: 14
maxBackups: 10
path: ~/logs/router
rateLimit:
enabled: true
# How many requests per ip adress are allowed
bucketSize: 50
# How many requests per ip address are refilled
refillSize: 50
# How often requests per ip address are refilled
refillTime: 30s
# How often Ip Addresses get cleaned up (only ip addresses with max allowed requests are cleaned up)
cleanupTime: 45s
hosts:
# Remote address to request
- remotes:
- localhost
- 192.168.2.154
# Port on which to request
port: 8181
# Health check if announce is true
public: true
# Domains which get redirected to host
domains:
- localhost
- test.localhost
- remotes:
- localhost
port: 8282
public: false
domains:
- private.localhost
- remotes:
- www.google.com
- localhost
port: 443
public: false
# Uses https under the hood to communicate with the remote host
secure: true
domains:
- google.localhost
general:
# Expose health endpoint, that requests health endpoints from hosts which are public
announce: true
# Path to health endpoint on router, is allowed to conflict with hosts, but overwrites specific host endpoint
healthz: healthz
server:
port: 443
ssl:
enabled: true
certFile: server.crt
keyFile: server.key
acme:
enabled: true
email: me@pablu.de
keyFile: userKey.key
caDirUrl: https://192.168.2.154:14000/dir
logging:
level: debug
# Pretty print for human consumption otherwise json
pretty: true
# Log incoming requests
requests: true
# Log to file aswell as stderr
file:
enabled: false
maxAge: 14
maxBackups: 10
path: ~/logs/router
rateLimit:
enabled: true
# How many requests per ip adress are allowed
bucketSize: 50
# How many requests per ip address are refilled
refillSize: 50
# How often requests per ip address are refilled
refillTime: 30s
# How often Ip Addresses get cleaned up (only ip addresses with max allowed requests are cleaned up)
cleanupTime: 45s
hosts:
# Remote address to request
- remotes:
- localhost
- 192.168.2.154
# Port on which to request
port: 8181
# Health check if announce is true
public: true
# Domains which get redirected to host
domains:
- localhost
- test.localhost
- remotes:
- localhost
port: 8282
public: false
domains:
- private.localhost
# - remotes:
# - www.google.com
# - localhost
# port: 443
# public: false
# # Uses https under the hood to communicate with the remote host
# secure: true
# domains:
# - google.localhost
general:
# Expose health endpoint, that requests health endpoints from hosts which are public
announce: true
# Path to health endpoint on router, is allowed to conflict with hosts, but overwrites specific host endpoint
healthz: healthz

26
go.mod
View File

@@ -1,14 +1,32 @@
module github.com/pablu23/domain-router
go 1.23
go 1.23.0
require github.com/rs/zerolog v1.33.0
toolchain go1.24.4
require github.com/rs/zerolog v1.34.0
require (
github.com/cenkalti/backoff v2.2.1+incompatible // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/go-jose/go-jose/v4 v4.0.5 // indirect
github.com/miekg/dns v1.1.67 // indirect
golang.org/x/crypto v0.40.0 // indirect
golang.org/x/mod v0.25.0 // indirect
golang.org/x/net v0.41.0 // indirect
golang.org/x/sync v0.16.0 // indirect
golang.org/x/text v0.27.0 // indirect
golang.org/x/tools v0.34.0 // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
)
require (
github.com/go-acme/lego v2.7.2+incompatible
github.com/go-acme/lego/v4 v4.24.0
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/urfave/negroni v1.0.0
golang.org/x/sys v0.12.0 // indirect
golang.org/x/sys v0.34.0 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v3 v3.0.1
)

33
go.sum
View File

@@ -1,22 +1,55 @@
github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4=
github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM=
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/go-acme/lego v2.7.2+incompatible h1:ThhpPBgf6oa9X/vRd0kEmWOsX7+vmYdckmGZSb+FEp0=
github.com/go-acme/lego v2.7.2+incompatible/go.mod h1:yzMNe9CasVUhkquNvti5nAtPmG94USbYxYrZfTkIn0M=
github.com/go-acme/lego/v4 v4.24.0 h1:pe0q49JKxfSGEP3lkgkMVQrZM1KbD+e0dpJ2McYsiVw=
github.com/go-acme/lego/v4 v4.24.0/go.mod h1:hkstZY6D0jylIrZbuNmEQrWQxTIfaJH7prwaWvKDOjw=
github.com/go-jose/go-jose/v4 v4.0.5 h1:M6T8+mKZl/+fNNuFHvGIzDz7BTLQPIounk/b9dw3AaE=
github.com/go-jose/go-jose/v4 v4.0.5/go.mod h1:s3P1lRrkT8igV8D9OjyL4WRyHvjB6a4JSllnOrmmBOA=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/miekg/dns v1.1.67 h1:kg0EHj0G4bfT5/oOys6HhZw4vmMlnoZ+gDu8tJ/AlI0=
github.com/miekg/dns v1.1.67/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
github.com/urfave/negroni v1.0.0 h1:kIimOitoypq34K7TG7DUaJ9kq/N4Ofuwi1sjz0KipXc=
github.com/urfave/negroni v1.0.0/go.mod h1:Meg73S6kFm/4PpbYdq35yYWoCZ9mS/YSx+lKnmiohz4=
golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w=
golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo=
golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI=
gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -1,123 +1,122 @@
package middleware
import (
"net/http"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/rs/zerolog/log"
)
type Limiter struct {
currentBuckets map[string]*atomic.Int64
bucketSize int
refillTicker *time.Ticker
cleanupTicker *time.Ticker
bucketRefill int
rwLock *sync.RWMutex
rateChannel chan string
}
func NewLimiter(maxRequests int, refills int, refillInterval time.Duration, cleanupInterval time.Duration) Limiter {
return Limiter{
currentBuckets: make(map[string]*atomic.Int64),
bucketSize: maxRequests,
refillTicker: time.NewTicker(refillInterval),
cleanupTicker: time.NewTicker(cleanupInterval),
bucketRefill: refills,
rwLock: &sync.RWMutex{},
rateChannel: make(chan string),
}
}
func (l *Limiter) Start() {
go l.Manage()
return
}
func (l *Limiter) UpdateCleanupTime(new time.Duration) {
l.cleanupTicker.Reset(new)
}
func (l *Limiter) Manage() {
for {
select {
case ip := <-l.rateChannel:
if l.AddIfExists(ip) {
break
}
l.rwLock.Lock()
counter := &atomic.Int64{}
l.currentBuckets[ip] = counter
l.rwLock.Unlock()
case <-l.refillTicker.C:
l.rwLock.RLock()
for ip := range l.currentBuckets {
n := l.currentBuckets[ip].Add(int64(-l.bucketRefill))
if n < 0 {
l.currentBuckets[ip].Store(0)
n = 0
}
log.Trace().Int64("bucket", n).Str("remote", ip).Msg("Updated limit")
}
l.rwLock.RUnlock()
log.Trace().Msg("Refreshed Limits")
case <-l.cleanupTicker.C:
start := time.Now()
l.rwLock.Lock()
deletedBuckets := 0
for ip := range l.currentBuckets {
if l.currentBuckets[ip].Load() <= 0 {
delete(l.currentBuckets, ip)
deletedBuckets += 1
}
}
l.rwLock.Unlock()
duration := time.Since(start)
log.Debug().Str("duration", duration.String()).Int("deleted_buckets", deletedBuckets).Msg("Cleaned up Buckets")
}
}
}
// Adds one if ip already exists and returns true
// If ip doesnt yet exist only returns false
func (l *Limiter) AddIfExists(ip string) bool {
l.rwLock.RLock()
defer l.rwLock.RUnlock()
if counter, ok := l.currentBuckets[ip]; ok {
counter.Add(1)
return true
}
return false
}
func (l *Limiter) RateLimiter(next http.Handler) http.Handler {
log.Info().Int("bucket_size", l.bucketSize).Int("bucket_refill", l.bucketRefill).Msg("Enabling Ratelimits")
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
addr := strings.Split(r.RemoteAddr, ":")[0]
l.rwLock.RLock()
count, ok := l.currentBuckets[addr]
l.rwLock.RUnlock()
if ok && int(count.Load()) >= l.bucketSize {
hj, ok := w.(http.Hijacker)
if !ok {
r.Body.Close()
log.Warn().Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Rate limited")
return
}
conn, _, err := hj.Hijack()
if err != nil {
log.Error().Err(err).Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Could not hijack connection")
}
log.Warn().Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Rate limited")
conn.Close()
return
}
l.rateChannel <- addr
next.ServeHTTP(w, r)
})
}
package middleware
import (
"net/http"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/rs/zerolog/log"
)
type Limiter struct {
currentBuckets map[string]*atomic.Int64
bucketSize int
refillTicker *time.Ticker
cleanupTicker *time.Ticker
bucketRefill int
rwLock *sync.RWMutex
rateChannel chan string
}
func NewLimiter(maxRequests int, refills int, refillInterval time.Duration, cleanupInterval time.Duration) Limiter {
return Limiter{
currentBuckets: make(map[string]*atomic.Int64),
bucketSize: maxRequests,
refillTicker: time.NewTicker(refillInterval),
cleanupTicker: time.NewTicker(cleanupInterval),
bucketRefill: refills,
rwLock: &sync.RWMutex{},
rateChannel: make(chan string),
}
}
func (l *Limiter) Start() {
go l.Manage()
}
func (l *Limiter) UpdateCleanupTime(new time.Duration) {
l.cleanupTicker.Reset(new)
}
func (l *Limiter) Manage() {
for {
select {
case ip := <-l.rateChannel:
if l.AddIfExists(ip) {
break
}
l.rwLock.Lock()
counter := &atomic.Int64{}
l.currentBuckets[ip] = counter
l.rwLock.Unlock()
case <-l.refillTicker.C:
l.rwLock.RLock()
for ip := range l.currentBuckets {
n := l.currentBuckets[ip].Add(int64(-l.bucketRefill))
if n < 0 {
l.currentBuckets[ip].Store(0)
n = 0
}
log.Trace().Int64("bucket", n).Str("remote", ip).Msg("Updated limit")
}
l.rwLock.RUnlock()
log.Trace().Msg("Refreshed Limits")
case <-l.cleanupTicker.C:
start := time.Now()
l.rwLock.Lock()
deletedBuckets := 0
for ip := range l.currentBuckets {
if l.currentBuckets[ip].Load() <= 0 {
delete(l.currentBuckets, ip)
deletedBuckets += 1
}
}
l.rwLock.Unlock()
duration := time.Since(start)
log.Debug().Str("duration", duration.String()).Int("deleted_buckets", deletedBuckets).Msg("Cleaned up Buckets")
}
}
}
// Adds one if ip already exists and returns true
// If ip doesnt yet exist only returns false
func (l *Limiter) AddIfExists(ip string) bool {
l.rwLock.RLock()
defer l.rwLock.RUnlock()
if counter, ok := l.currentBuckets[ip]; ok {
counter.Add(1)
return true
}
return false
}
func (l *Limiter) RateLimiter(next http.Handler) http.Handler {
log.Info().Int("bucket_size", l.bucketSize).Int("bucket_refill", l.bucketRefill).Msg("Enabling Ratelimits")
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
addr := strings.Split(r.RemoteAddr, ":")[0]
l.rwLock.RLock()
count, ok := l.currentBuckets[addr]
l.rwLock.RUnlock()
if ok && int(count.Load()) >= l.bucketSize {
hj, ok := w.(http.Hijacker)
if !ok {
r.Body.Close()
log.Warn().Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Rate limited")
return
}
conn, _, err := hj.Hijack()
if err != nil {
log.Error().Err(err).Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Could not hijack connection")
}
log.Warn().Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Rate limited")
conn.Close()
return
}
l.rateChannel <- addr
next.ServeHTTP(w, r)
})
}