Added first acme implementation

This commit is contained in:
Pablu23
2025-07-20 18:34:20 +02:00
parent dc2fe84a96
commit 77a880cee1
8 changed files with 591 additions and 386 deletions

111
acme/acme.go Normal file
View File

@@ -0,0 +1,111 @@
package acme
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"net/http"
"os"
"github.com/go-acme/lego/v4/certcrypto"
"github.com/go-acme/lego/v4/certificate"
"github.com/go-acme/lego/v4/challenge/http01"
"github.com/go-acme/lego/v4/challenge/tlsalpn01"
"github.com/go-acme/lego/v4/lego"
"github.com/go-acme/lego/v4/registration"
domainrouter "github.com/pablu23/domain-router"
)
func SetupAcme(config *domainrouter.Config) error {
acme := config.Server.Ssl.Acme
var privateKey *ecdsa.PrivateKey
if _, err := os.Stat(acme.KeyFile); errors.Is(err, os.ErrNotExist) {
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return err
}
err = os.WriteFile(acme.KeyFile, []byte(encode(privateKey)), 0666)
if err != nil {
return err
}
} else {
keyBytes, err := os.ReadFile(acme.KeyFile)
if err != nil {
return err
}
privateKey = decode(string(keyBytes))
}
user := User{
Email: acme.Email,
key: privateKey,
}
leConfig := lego.NewConfig(&user)
leConfig.CADirURL = acme.CADirURL
leConfig.Certificate.KeyType = certcrypto.RSA2048
leConfig.HTTPClient.Transport = &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
client, err := lego.NewClient(leConfig)
if err != nil {
return err
}
// strconv.Itoa(config.Server.Port)
err = client.Challenge.SetHTTP01Provider(http01.NewProviderServer("", "5002"))
if err != nil {
return err
}
err = client.Challenge.SetTLSALPN01Provider(tlsalpn01.NewProviderServer("", "5001"))
if err != nil {
return err
}
reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
if err != nil {
return err
}
user.Registration = reg
domains := make([]string, 0)
for _, host := range config.Hosts {
domains = append(domains, host.Domains...)
}
request := certificate.ObtainRequest{
Domains: domains,
Bundle: true,
}
certificates, err := client.Certificate.Obtain(request)
if err != nil {
return err
}
fmt.Printf("%#v\n", certificates)
return nil
}
func encode(privateKey *ecdsa.PrivateKey) string {
x509Encoded, _ := x509.MarshalECPrivateKey(privateKey)
pemEncoded := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: x509Encoded})
return string(pemEncoded)
}
func decode(pemEncoded string) *ecdsa.PrivateKey {
block, _ := pem.Decode([]byte(pemEncoded))
x509Encoded := block.Bytes
privateKey, _ := x509.ParseECPrivateKey(x509Encoded)
return privateKey
}

25
acme/user.go Normal file
View File

@@ -0,0 +1,25 @@
package acme
import (
"crypto"
"github.com/go-acme/lego/v4/registration"
)
type User struct {
Email string
Registration *registration.Resource
key crypto.PrivateKey
}
func (u *User) GetEmail() string {
return u.Email
}
func (u *User) GetRegistration() *registration.Resource {
return u.Registration
}
func (u *User) GetPrivateKey() crypto.PrivateKey {
return u.key
}

View File

@@ -1,148 +1,156 @@
package main package main
import ( import (
"crypto/tls" "crypto/tls"
"flag" "flag"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"time" "time"
domainrouter "github.com/pablu23/domain-router" domainrouter "github.com/pablu23/domain-router"
"github.com/pablu23/domain-router/middleware" "github.com/pablu23/domain-router/acme"
"github.com/rs/zerolog" "github.com/pablu23/domain-router/middleware"
"github.com/rs/zerolog/log" "github.com/rs/zerolog"
"gopkg.in/natefinch/lumberjack.v2" "github.com/rs/zerolog/log"
"gopkg.in/yaml.v3" "gopkg.in/natefinch/lumberjack.v2"
) "gopkg.in/yaml.v3"
)
var (
configFileFlag = flag.String("config", "config.yaml", "Path to config file") var (
) configFileFlag = flag.String("config", "config.yaml", "Path to config file")
)
func main() {
flag.Parse() func main() {
flag.Parse()
config, err := loadConfig(*configFileFlag)
if err != nil { config, err := loadConfig(*configFileFlag)
log.Fatal().Err(err).Str("path", *configFileFlag).Msg("Could not load Config") if err != nil {
} log.Fatal().Err(err).Str("path", *configFileFlag).Msg("Could not load Config")
}
setupLogging(config)
client := &http.Client{ setupLogging(config)
CheckRedirect: func(req *http.Request, via []*http.Request) error { client := &http.Client{
return http.ErrUseLastResponse CheckRedirect: func(req *http.Request, via []*http.Request) error {
}, return http.ErrUseLastResponse
} },
}
router := domainrouter.New(config, client)
mux := http.NewServeMux() router := domainrouter.New(config, client)
mux.HandleFunc("/", router.Route) mux := http.NewServeMux()
mux.HandleFunc("/", router.Route)
if config.General.AnnouncePublic {
h, err := url.JoinPath("/", config.General.HealthEndpoint) if config.General.AnnouncePublic {
if err != nil { h, err := url.JoinPath("/", config.General.HealthEndpoint)
log.Error().Err(err).Str("endpoint", config.General.HealthEndpoint).Msg("Could not create endpoint path") if err != nil {
h = "/healthz" log.Error().Err(err).Str("endpoint", config.General.HealthEndpoint).Msg("Could not create endpoint path")
} h = "/healthz"
mux.HandleFunc(h, router.Healthz) }
} mux.HandleFunc(h, router.Healthz)
}
pipeline := configureMiddleware(config)
pipeline := configureMiddleware(config)
server := http.Server{
Addr: fmt.Sprintf(":%d", config.Server.Port), server := http.Server{
Handler: pipeline(mux), Addr: fmt.Sprintf(":%d", config.Server.Port),
} Handler: pipeline(mux),
}
if config.Server.Ssl.Enabled {
server.TLSConfig = &tls.Config{ if config.Server.Ssl.Enabled {
GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { server.TLSConfig = &tls.Config{
cert, err := tls.LoadX509KeyPair(config.Server.Ssl.CertFile, config.Server.Ssl.KeyFile) GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
if err != nil { cert, err := tls.LoadX509KeyPair(config.Server.Ssl.CertFile, config.Server.Ssl.KeyFile)
return nil, err if err != nil {
} return nil, err
return &cert, err }
}, return &cert, err
} },
}
log.Info().Int("port", config.Server.Port).Str("cert", config.Server.Ssl.CertFile).Str("key", config.Server.Ssl.KeyFile).Msg("Starting server")
err := server.ListenAndServeTLS("", "") if config.Server.Ssl.Acme.Enabled {
log.Fatal().Err(err).Str("cert", config.Server.Ssl.CertFile).Str("key", config.Server.Ssl.KeyFile).Int("port", config.Server.Port).Msg("Could not start server") err := acme.SetupAcme(config)
} else { if err != nil {
log.Info().Int("port", config.Server.Port).Msg("Starting server") log.Fatal().Err(err).Msg("unable to setup acme")
err := server.ListenAndServe() }
log.Fatal().Err(err).Int("port", config.Server.Port).Msg("Could not start server") }
}
} log.Info().Int("port", config.Server.Port).Str("cert", config.Server.Ssl.CertFile).Str("key", config.Server.Ssl.KeyFile).Msg("Starting server")
err := server.ListenAndServeTLS("", "")
func configureMiddleware(config *domainrouter.Config) middleware.Middleware { log.Fatal().Err(err).Str("cert", config.Server.Ssl.CertFile).Str("key", config.Server.Ssl.KeyFile).Int("port", config.Server.Port).Msg("Could not start server")
middlewares := make([]middleware.Middleware, 0) } else {
log.Info().Int("port", config.Server.Port).Msg("Starting server")
if config.RateLimit.Enabled { err := server.ListenAndServe()
refillTicker, err := time.ParseDuration(config.RateLimit.RefillTicker) log.Fatal().Err(err).Int("port", config.Server.Port).Msg("Could not start server")
if err != nil { }
log.Fatal().Err(err).Str("refill", config.RateLimit.RefillTicker).Msg("Could not parse refill Ticker") }
}
func configureMiddleware(config *domainrouter.Config) middleware.Middleware {
cleanupTicker, err := time.ParseDuration(config.RateLimit.CleanupTicker) middlewares := make([]middleware.Middleware, 0)
if err != nil {
log.Fatal().Err(err).Str("cleanup", config.RateLimit.CleanupTicker).Msg("Could not parse cleanup Ticker") if config.RateLimit.Enabled {
} refillTicker, err := time.ParseDuration(config.RateLimit.RefillTicker)
limiter := middleware.NewLimiter(config.RateLimit.BucketSize, config.RateLimit.BucketRefill, refillTicker, cleanupTicker) if err != nil {
limiter.Start() log.Fatal().Err(err).Str("refill", config.RateLimit.RefillTicker).Msg("Could not parse refill Ticker")
middlewares = append(middlewares, limiter.RateLimiter) }
}
cleanupTicker, err := time.ParseDuration(config.RateLimit.CleanupTicker)
if config.Logging.Requests { if err != nil {
middlewares = append(middlewares, middleware.RequestLogger) log.Fatal().Err(err).Str("cleanup", config.RateLimit.CleanupTicker).Msg("Could not parse cleanup Ticker")
} }
limiter := middleware.NewLimiter(config.RateLimit.BucketSize, config.RateLimit.BucketRefill, refillTicker, cleanupTicker)
pipeline := middleware.Pipeline(middlewares...) limiter.Start()
return pipeline middlewares = append(middlewares, limiter.RateLimiter)
} }
func setupLogging(config *domainrouter.Config) { if config.Logging.Requests {
logLevel, err := zerolog.ParseLevel(config.Logging.Level) middlewares = append(middlewares, middleware.RequestLogger)
if err != nil { }
log.Fatal().Err(err).Str("level", config.Logging.Level).Msg("Could not parse string to level")
} pipeline := middleware.Pipeline(middlewares...)
return pipeline
zerolog.SetGlobalLevel(logLevel) }
log.Info().Str("level", config.Logging.Level).Msg("Set logging level")
if config.Logging.Pretty { func setupLogging(config *domainrouter.Config) {
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) logLevel, err := zerolog.ParseLevel(config.Logging.Level)
} if err != nil {
log.Fatal().Err(err).Str("level", config.Logging.Level).Msg("Could not parse string to level")
if config.Logging.File.Enabled { }
var console io.Writer = os.Stderr
if config.Logging.Pretty { zerolog.SetGlobalLevel(logLevel)
console = zerolog.ConsoleWriter{Out: os.Stderr} log.Info().Str("level", config.Logging.Level).Msg("Set logging level")
} if config.Logging.Pretty {
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
log.Logger = log.Output(zerolog.MultiLevelWriter(console, &lumberjack.Logger{ }
Filename: config.Logging.File.Path,
MaxAge: config.Logging.File.MaxAge, if config.Logging.File.Enabled {
MaxBackups: config.Logging.File.MaxBackups, var console io.Writer = os.Stderr
})) if config.Logging.Pretty {
} console = zerolog.ConsoleWriter{Out: os.Stderr}
} }
func loadConfig(path string) (*domainrouter.Config, error) { log.Logger = log.Output(zerolog.MultiLevelWriter(console, &lumberjack.Logger{
f, err := os.Open(path) Filename: config.Logging.File.Path,
if err != nil { MaxAge: config.Logging.File.MaxAge,
return nil, err MaxBackups: config.Logging.File.MaxBackups,
} }))
defer f.Close() }
}
var cfg domainrouter.Config
decoder := yaml.NewDecoder(f) func loadConfig(path string) (*domainrouter.Config, error) {
err = decoder.Decode(&cfg) f, err := os.Open(path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer f.Close()
return &cfg, err
} var cfg domainrouter.Config
decoder := yaml.NewDecoder(f)
err = decoder.Decode(&cfg)
if err != nil {
return nil, err
}
return &cfg, err
}

View File

@@ -1,41 +1,47 @@
package domainrouter package domainrouter
type Config struct { type Config struct {
General struct { General struct {
AnnouncePublic bool `yaml:"announce"` AnnouncePublic bool `yaml:"announce"`
HealthEndpoint string `yaml:"healthz"` HealthEndpoint string `yaml:"healthz"`
} `yaml:"general"` } `yaml:"general"`
Server struct { Server struct {
Port int `yaml:"port"` Port int `yaml:"port"`
Ssl struct { Ssl struct {
Enabled bool `yaml:"enabled"` Enabled bool `yaml:"enabled"`
CertFile string `yaml:"certFile"` CertFile string `yaml:"certFile"`
KeyFile string `yaml:"keyFile"` KeyFile string `yaml:"keyFile"`
} `yaml:"ssl"` Acme struct {
} `yaml:"server"` Enabled bool `yaml:"enabled"`
Hosts []struct { Email string `yaml:"email"`
Port int `yaml:"port"` KeyFile string `yaml:"keyFile"`
Remotes []string `yaml:"remotes"` CADirURL string `yaml:"caDirUrl"`
Domains []string `yaml:"domains"` } `yaml:"acme"`
Public bool `yaml:"public"` } `yaml:"ssl"`
Secure bool `yaml:"secure"` } `yaml:"server"`
} `yaml:"hosts"` Hosts []struct {
RateLimit struct { Port int `yaml:"port"`
Enabled bool `yaml:"enabled"` Remotes []string `yaml:"remotes"`
BucketSize int `yaml:"bucketSize"` Domains []string `yaml:"domains"`
RefillTicker string `yaml:"refillTime"` Public bool `yaml:"public"`
CleanupTicker string `yaml:"cleanupTime"` Secure bool `yaml:"secure"`
BucketRefill int `yaml:"refillSize"` } `yaml:"hosts"`
} `yaml:"rateLimit"` RateLimit struct {
Logging struct { Enabled bool `yaml:"enabled"`
Level string `yaml:"level"` BucketSize int `yaml:"bucketSize"`
Pretty bool `yaml:"pretty"` RefillTicker string `yaml:"refillTime"`
Requests bool `yaml:"requests"` CleanupTicker string `yaml:"cleanupTime"`
File struct { BucketRefill int `yaml:"refillSize"`
Enabled bool `yaml:"enabled"` } `yaml:"rateLimit"`
Path string `yaml:"path"` Logging struct {
MaxAge int `yaml:"maxAge"` Level string `yaml:"level"`
MaxBackups int `yamls:"maxBackups"` Pretty bool `yaml:"pretty"`
} `yaml:"file"` Requests bool `yaml:"requests"`
} `yaml:"logging"` File struct {
} Enabled bool `yaml:"enabled"`
Path string `yaml:"path"`
MaxAge int `yaml:"maxAge"`
MaxBackups int `yamls:"maxBackups"`
} `yaml:"file"`
} `yaml:"logging"`
}

View File

@@ -1,70 +1,75 @@
server: server:
port: 443 port: 443
ssl: ssl:
enabled: true enabled: true
certFile: server.crt certFile: server.crt
keyFile: server.key keyFile: server.key
acme:
enabled: true
logging: email: me@pablu.de
level: debug keyFile: userKey.key
# Pretty print for human consumption otherwise json caDirUrl: https://192.168.2.154:14000/dir
pretty: true
# Log incoming requests
requests: true logging:
# Log to file aswell as stderr level: debug
file: # Pretty print for human consumption otherwise json
enabled: false pretty: true
maxAge: 14 # Log incoming requests
maxBackups: 10 requests: true
path: ~/logs/router # Log to file aswell as stderr
file:
enabled: false
rateLimit: maxAge: 14
enabled: true maxBackups: 10
# How many requests per ip adress are allowed path: ~/logs/router
bucketSize: 50
# How many requests per ip address are refilled
refillSize: 50 rateLimit:
# How often requests per ip address are refilled enabled: true
refillTime: 30s # How many requests per ip adress are allowed
# How often Ip Addresses get cleaned up (only ip addresses with max allowed requests are cleaned up) bucketSize: 50
cleanupTime: 45s # How many requests per ip address are refilled
refillSize: 50
# How often requests per ip address are refilled
hosts: refillTime: 30s
# Remote address to request # How often Ip Addresses get cleaned up (only ip addresses with max allowed requests are cleaned up)
- remotes: cleanupTime: 45s
- localhost
- 192.168.2.154
# Port on which to request hosts:
port: 8181 # Remote address to request
# Health check if announce is true - remotes:
public: true - localhost
# Domains which get redirected to host - 192.168.2.154
domains: # Port on which to request
- localhost port: 8181
- test.localhost # Health check if announce is true
public: true
- remotes: # Domains which get redirected to host
- localhost domains:
port: 8282 - localhost
public: false - test.localhost
domains:
- private.localhost - remotes:
- localhost
- remotes: port: 8282
- www.google.com public: false
- localhost domains:
port: 443 - private.localhost
public: false
# Uses https under the hood to communicate with the remote host # - remotes:
secure: true # - www.google.com
domains: # - localhost
- google.localhost # port: 443
# public: false
general: # # Uses https under the hood to communicate with the remote host
# Expose health endpoint, that requests health endpoints from hosts which are public # secure: true
announce: true # domains:
# Path to health endpoint on router, is allowed to conflict with hosts, but overwrites specific host endpoint # - google.localhost
healthz: healthz
general:
# Expose health endpoint, that requests health endpoints from hosts which are public
announce: true
# Path to health endpoint on router, is allowed to conflict with hosts, but overwrites specific host endpoint
healthz: healthz

26
go.mod
View File

@@ -1,14 +1,32 @@
module github.com/pablu23/domain-router module github.com/pablu23/domain-router
go 1.23 go 1.23.0
require github.com/rs/zerolog v1.33.0 toolchain go1.24.4
require github.com/rs/zerolog v1.34.0
require ( require (
github.com/cenkalti/backoff v2.2.1+incompatible // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/go-jose/go-jose/v4 v4.0.5 // indirect
github.com/miekg/dns v1.1.67 // indirect
golang.org/x/crypto v0.40.0 // indirect
golang.org/x/mod v0.25.0 // indirect
golang.org/x/net v0.41.0 // indirect
golang.org/x/sync v0.16.0 // indirect
golang.org/x/text v0.27.0 // indirect
golang.org/x/tools v0.34.0 // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
)
require (
github.com/go-acme/lego v2.7.2+incompatible
github.com/go-acme/lego/v4 v4.24.0
github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/urfave/negroni v1.0.0 github.com/urfave/negroni v1.0.0
golang.org/x/sys v0.12.0 // indirect golang.org/x/sys v0.34.0 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
) )

33
go.sum
View File

@@ -1,22 +1,55 @@
github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4=
github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM=
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/go-acme/lego v2.7.2+incompatible h1:ThhpPBgf6oa9X/vRd0kEmWOsX7+vmYdckmGZSb+FEp0=
github.com/go-acme/lego v2.7.2+incompatible/go.mod h1:yzMNe9CasVUhkquNvti5nAtPmG94USbYxYrZfTkIn0M=
github.com/go-acme/lego/v4 v4.24.0 h1:pe0q49JKxfSGEP3lkgkMVQrZM1KbD+e0dpJ2McYsiVw=
github.com/go-acme/lego/v4 v4.24.0/go.mod h1:hkstZY6D0jylIrZbuNmEQrWQxTIfaJH7prwaWvKDOjw=
github.com/go-jose/go-jose/v4 v4.0.5 h1:M6T8+mKZl/+fNNuFHvGIzDz7BTLQPIounk/b9dw3AaE=
github.com/go-jose/go-jose/v4 v4.0.5/go.mod h1:s3P1lRrkT8igV8D9OjyL4WRyHvjB6a4JSllnOrmmBOA=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/miekg/dns v1.1.67 h1:kg0EHj0G4bfT5/oOys6HhZw4vmMlnoZ+gDu8tJ/AlI0=
github.com/miekg/dns v1.1.67/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
github.com/urfave/negroni v1.0.0 h1:kIimOitoypq34K7TG7DUaJ9kq/N4Ofuwi1sjz0KipXc= github.com/urfave/negroni v1.0.0 h1:kIimOitoypq34K7TG7DUaJ9kq/N4Ofuwi1sjz0KipXc=
github.com/urfave/negroni v1.0.0/go.mod h1:Meg73S6kFm/4PpbYdq35yYWoCZ9mS/YSx+lKnmiohz4= github.com/urfave/negroni v1.0.0/go.mod h1:Meg73S6kFm/4PpbYdq35yYWoCZ9mS/YSx+lKnmiohz4=
golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w=
golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo=
golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI=
gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -1,123 +1,122 @@
package middleware package middleware
import ( import (
"net/http" "net/http"
"strings" "strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
type Limiter struct { type Limiter struct {
currentBuckets map[string]*atomic.Int64 currentBuckets map[string]*atomic.Int64
bucketSize int bucketSize int
refillTicker *time.Ticker refillTicker *time.Ticker
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
bucketRefill int bucketRefill int
rwLock *sync.RWMutex rwLock *sync.RWMutex
rateChannel chan string rateChannel chan string
} }
func NewLimiter(maxRequests int, refills int, refillInterval time.Duration, cleanupInterval time.Duration) Limiter { func NewLimiter(maxRequests int, refills int, refillInterval time.Duration, cleanupInterval time.Duration) Limiter {
return Limiter{ return Limiter{
currentBuckets: make(map[string]*atomic.Int64), currentBuckets: make(map[string]*atomic.Int64),
bucketSize: maxRequests, bucketSize: maxRequests,
refillTicker: time.NewTicker(refillInterval), refillTicker: time.NewTicker(refillInterval),
cleanupTicker: time.NewTicker(cleanupInterval), cleanupTicker: time.NewTicker(cleanupInterval),
bucketRefill: refills, bucketRefill: refills,
rwLock: &sync.RWMutex{}, rwLock: &sync.RWMutex{},
rateChannel: make(chan string), rateChannel: make(chan string),
} }
} }
func (l *Limiter) Start() { func (l *Limiter) Start() {
go l.Manage() go l.Manage()
return }
}
func (l *Limiter) UpdateCleanupTime(new time.Duration) {
func (l *Limiter) UpdateCleanupTime(new time.Duration) { l.cleanupTicker.Reset(new)
l.cleanupTicker.Reset(new) }
}
func (l *Limiter) Manage() {
func (l *Limiter) Manage() { for {
for { select {
select { case ip := <-l.rateChannel:
case ip := <-l.rateChannel: if l.AddIfExists(ip) {
if l.AddIfExists(ip) { break
break }
}
l.rwLock.Lock()
l.rwLock.Lock() counter := &atomic.Int64{}
counter := &atomic.Int64{} l.currentBuckets[ip] = counter
l.currentBuckets[ip] = counter l.rwLock.Unlock()
l.rwLock.Unlock() case <-l.refillTicker.C:
case <-l.refillTicker.C: l.rwLock.RLock()
l.rwLock.RLock() for ip := range l.currentBuckets {
for ip := range l.currentBuckets { n := l.currentBuckets[ip].Add(int64(-l.bucketRefill))
n := l.currentBuckets[ip].Add(int64(-l.bucketRefill)) if n < 0 {
if n < 0 { l.currentBuckets[ip].Store(0)
l.currentBuckets[ip].Store(0) n = 0
n = 0 }
} log.Trace().Int64("bucket", n).Str("remote", ip).Msg("Updated limit")
log.Trace().Int64("bucket", n).Str("remote", ip).Msg("Updated limit") }
} l.rwLock.RUnlock()
l.rwLock.RUnlock() log.Trace().Msg("Refreshed Limits")
log.Trace().Msg("Refreshed Limits") case <-l.cleanupTicker.C:
case <-l.cleanupTicker.C: start := time.Now()
start := time.Now() l.rwLock.Lock()
l.rwLock.Lock() deletedBuckets := 0
deletedBuckets := 0 for ip := range l.currentBuckets {
for ip := range l.currentBuckets { if l.currentBuckets[ip].Load() <= 0 {
if l.currentBuckets[ip].Load() <= 0 { delete(l.currentBuckets, ip)
delete(l.currentBuckets, ip) deletedBuckets += 1
deletedBuckets += 1 }
} }
} l.rwLock.Unlock()
l.rwLock.Unlock() duration := time.Since(start)
duration := time.Since(start) log.Debug().Str("duration", duration.String()).Int("deleted_buckets", deletedBuckets).Msg("Cleaned up Buckets")
log.Debug().Str("duration", duration.String()).Int("deleted_buckets", deletedBuckets).Msg("Cleaned up Buckets") }
} }
} }
}
// Adds one if ip already exists and returns true
// Adds one if ip already exists and returns true // If ip doesnt yet exist only returns false
// If ip doesnt yet exist only returns false func (l *Limiter) AddIfExists(ip string) bool {
func (l *Limiter) AddIfExists(ip string) bool { l.rwLock.RLock()
l.rwLock.RLock() defer l.rwLock.RUnlock()
defer l.rwLock.RUnlock() if counter, ok := l.currentBuckets[ip]; ok {
if counter, ok := l.currentBuckets[ip]; ok { counter.Add(1)
counter.Add(1) return true
return true }
} return false
return false }
}
func (l *Limiter) RateLimiter(next http.Handler) http.Handler {
func (l *Limiter) RateLimiter(next http.Handler) http.Handler { log.Info().Int("bucket_size", l.bucketSize).Int("bucket_refill", l.bucketRefill).Msg("Enabling Ratelimits")
log.Info().Int("bucket_size", l.bucketSize).Int("bucket_refill", l.bucketRefill).Msg("Enabling Ratelimits") return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { addr := strings.Split(r.RemoteAddr, ":")[0]
addr := strings.Split(r.RemoteAddr, ":")[0] l.rwLock.RLock()
l.rwLock.RLock() count, ok := l.currentBuckets[addr]
count, ok := l.currentBuckets[addr] l.rwLock.RUnlock()
l.rwLock.RUnlock() if ok && int(count.Load()) >= l.bucketSize {
if ok && int(count.Load()) >= l.bucketSize { hj, ok := w.(http.Hijacker)
hj, ok := w.(http.Hijacker) if !ok {
if !ok { r.Body.Close()
r.Body.Close() log.Warn().Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Rate limited")
log.Warn().Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Rate limited") return
return }
} conn, _, err := hj.Hijack()
conn, _, err := hj.Hijack() if err != nil {
if err != nil { log.Error().Err(err).Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Could not hijack connection")
log.Error().Err(err).Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Could not hijack connection") }
}
log.Warn().Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Rate limited")
log.Warn().Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Rate limited") conn.Close()
conn.Close() return
return }
} l.rateChannel <- addr
l.rateChannel <- addr next.ServeHTTP(w, r)
next.ServeHTTP(w, r) })
}) }
}