From 1dd90bbe888b770051dabd40c416a8d8bcca2f3a Mon Sep 17 00:00:00 2001 From: Pablu23 Date: Sun, 7 Jul 2024 18:46:49 +0200 Subject: [PATCH] Change project to framework with example implementation --- .gitignore | 1 + logging-middleware.go | 17 +--- main.go | 223 ------------------------------------------ rate-limit.go | 95 ++++++++++++++++++ router.go | 142 +++++++++++++++++++++++++++ 5 files changed, 239 insertions(+), 239 deletions(-) delete mode 100644 main.go create mode 100644 rate-limit.go create mode 100644 router.go diff --git a/.gitignore b/.gitignore index be870b4..fa92c96 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ *.crt *.key +domain-router diff --git a/logging-middleware.go b/logging-middleware.go index e678758..a09da3b 100644 --- a/logging-middleware.go +++ b/logging-middleware.go @@ -1,4 +1,4 @@ -package main +package domainrouter import ( "net/http" @@ -8,25 +8,10 @@ import ( "github.com/urfave/negroni" ) -// type loggingResponseWriter struct { -// http.ResponseWriter -// statusCode int -// } -// -// func NewLoggingResponseWriter(w http.ResponseWriter) *loggingResponseWriter { -// return &loggingResponseWriter{w, http.StatusOK} -// } -// -// func (lrw *loggingResponseWriter) WriteHeader(code int) { -// lrw.statusCode = code -// lrw.ResponseWriter.WriteHeader(code) -// } - func RequestLogger(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() - // lrw := NewLoggingResponseWriter(w) lrw := negroni.NewResponseWriter(w) next.ServeHTTP(lrw, r) diff --git a/main.go b/main.go deleted file mode 100644 index 90730ba..0000000 --- a/main.go +++ /dev/null @@ -1,223 +0,0 @@ -package main - -import ( - "bufio" - "crypto/tls" - "errors" - "flag" - "fmt" - "io" - "net/http" - "net/http/httputil" - "os" - "strconv" - "strings" - - "github.com/rs/zerolog" - "github.com/rs/zerolog/log" - "gopkg.in/natefinch/lumberjack.v2" -) - -var ( - configFileFlag = flag.String("config", "domains.conf", "Path to Domain config file") - certFlag = flag.String("cert", "", "Path to cert file") - keyFlag = flag.String("key", "", "Path to key file") - portFlag = flag.Int("port", 80, "Port") - prettyLogsFlag = flag.Bool("pretty", false, "Pretty print? Default is json") - logPathFlag = flag.String("log", "", "Path to logfile, default is stderr") - logLevelFlag = flag.String("log-level", "info", "Log Level") -) - -func main() { - flag.Parse() - - setupLogging() - - domains, err := loadConfig(*configFileFlag) - if err != nil { - log.Fatal().Err(err).Str("path", *configFileFlag).Msg("Could not load Config") - } - - client := &http.Client{ - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - } - - mux := http.NewServeMux() - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - port, ok := domains[r.Host] - if !ok { - w.WriteHeader(http.StatusOK) - return - } - - if !dumpRequest(w, r) { - return - } - - subUrlPath := r.URL.RequestURI() - req, err := http.NewRequest(r.Method, fmt.Sprintf("http://localhost:%d%s", port, subUrlPath), r.Body) - if err != nil { - log.Error().Err(err).Str("path", subUrlPath).Int("port", port).Msg("Could not create request") - w.WriteHeader(http.StatusInternalServerError) - return - } - - for name, values := range r.Header { - for _, value := range values { - req.Header.Set(name, value) - } - } - - for _, cookie := range r.Cookies() { - req.AddCookie(cookie) - } - - if !dumpRequest(w, req) { - return - } - - res, err := client.Do(req) - if err != nil { - log.Error().Err(err).Str("path", subUrlPath).Int("port", port).Msg("Could not complete request") - w.WriteHeader(http.StatusInternalServerError) - return - } - - cookies := res.Cookies() - for _, cookie := range cookies { - // fmt.Printf("Setting cookie, Name: %s, Value: %s\n", cookie.Name, cookie.Value) - http.SetCookie(w, cookie) - } - - if !dumpResponse(w, res) { - return - } - - if loc, err := res.Location(); !errors.Is(err, http.ErrNoLocation) { - http.Redirect(w, r, loc.RequestURI(), http.StatusFound) - } else { - for name, values := range res.Header { - for _, value := range values { - w.Header().Set(name, value) - } - } - w.WriteHeader(res.StatusCode) - - body, err := io.ReadAll(res.Body) - defer res.Body.Close() - if err != nil { - log.Error().Err(err).Msg("Could not read body") - w.WriteHeader(http.StatusInternalServerError) - return - } - - _, err = w.Write(body) - if err != nil { - log.Error().Err(err).Msg("Could not write body") - w.WriteHeader(http.StatusInternalServerError) - return - } - } - }) - - server := http.Server{ - Addr: fmt.Sprintf(":%d", *portFlag), - Handler: RequestLogger(mux), - } - - if *certFlag != "" && *keyFlag != "" { - server.TLSConfig = &tls.Config{ - GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { - cert, err := tls.LoadX509KeyPair(*certFlag, *keyFlag) - if err != nil { - return nil, err - } - return &cert, err - }, - } - log.Info().Int("port", *portFlag).Str("cert", *certFlag).Str("key", *keyFlag).Msg("Starting server") - err := server.ListenAndServeTLS("", "") - log.Fatal().Err(err).Str("cert", *certFlag).Str("key", *keyFlag).Int("port", *portFlag).Msg("Could not start server") - } else { - log.Info().Int("port", *portFlag).Msg("Starting server") - err := server.ListenAndServe() - log.Fatal().Err(err).Int("port", *portFlag).Msg("Could not start server") - } -} - -func setupLogging() { - logLevel, err := zerolog.ParseLevel(*logLevelFlag) - if err != nil { - log.Fatal().Err(err).Str("level", *logLevelFlag).Msg("Could not parse string to level") - } - - zerolog.SetGlobalLevel(logLevel) - if *prettyLogsFlag { - log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) - } - - if *logPathFlag != "" { - var console io.Writer = os.Stderr - if *prettyLogsFlag { - console = zerolog.ConsoleWriter{Out: os.Stderr} - } - log.Logger = log.Output(zerolog.MultiLevelWriter(console, &lumberjack.Logger{ - Filename: *logPathFlag, - MaxAge: 14, - MaxBackups: 10, - })) - } -} - -func dumpRequest(w http.ResponseWriter, r *http.Request) bool { - if e := log.Debug(); e.Enabled() && r.Method == "POST" { - rDump, err := httputil.DumpRequest(r, true) - if err != nil { - log.Error().Err(err).Msg("Could not dump request") - w.WriteHeader(http.StatusInternalServerError) - return false - } - log.Debug().Str("dump", string(rDump)).Send() - } - return true -} - -func dumpResponse(w http.ResponseWriter, r *http.Response) bool { - if e := log.Trace(); e.Enabled() { - dump, err := httputil.DumpResponse(r, true) - if err != nil { - log.Error().Err(err).Msg("Could not dump response") - w.WriteHeader(http.StatusInternalServerError) - return false - } - log.Trace().Str("dump", string(dump)).Send() - } - return true -} - -func loadConfig(path string) (map[string]int, error) { - file, err := os.Open(path) - if err != nil { - return nil, err - } - scanner := bufio.NewScanner(file) - scanner.Split(bufio.ScanLines) - - m := make(map[string]int) - for scanner.Scan() { - line := scanner.Text() - params := strings.Split(line, ";") - if len(params) <= 1 { - return nil, errors.New("Line does not contain enough Parameters") - } - port, err := strconv.Atoi(params[1]) - if err != nil { - return nil, err - } - m[params[0]] = port - } - - return m, nil -} diff --git a/rate-limit.go b/rate-limit.go new file mode 100644 index 0000000..479128a --- /dev/null +++ b/rate-limit.go @@ -0,0 +1,95 @@ +package domainrouter + +import ( + "net/http" + "strings" + "sync" + "time" + + "github.com/rs/zerolog/log" +) + +type Limiter struct { + current map[string]int + max int + ticker *time.Ticker + refill int + m *sync.RWMutex + c chan string +} + +func NewLimiter(maxRequests int, refills int, refillInterval time.Duration) Limiter { + return Limiter{ + current: make(map[string]int), + max: maxRequests, + ticker: time.NewTicker(refillInterval), + refill: refills, + m: &sync.RWMutex{}, + c: make(chan string), + } +} + +func (l *Limiter) Start() { + go l.Manage() + return +} + +func (l *Limiter) Manage() { + for { + select { + case ip := <-l.c: + l.m.Lock() + if _, ok := l.current[ip]; ok { + l.current[ip] += 1 + } else { + l.current[ip] = 1 + } + l.m.Unlock() + case <-l.ticker.C: + l.m.Lock() + start := time.Now() + count := len(l.current) + deleted := 0 + for ip, times := range l.current { + if times-l.refill <= 0 { + deleted += 1 + delete(l.current, ip) + } else { + l.current[ip] -= l.refill + } + } + l.m.Unlock() + elapsed := time.Since(start) + if count >= 1 { + log.Info().Int("ips", count).Int("forgotten", deleted).Str("duration", elapsed.String()).Msg("Refill rate limit") + } + } + } +} + +func (l *Limiter) RateLimiter(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + addr := strings.Split(r.RemoteAddr, ":")[0] + l.m.RLock() + count, ok := l.current[addr] + l.m.RUnlock() + if ok && count >= l.max { + hj, ok := w.(http.Hijacker) + if !ok { + r.Body.Close() + log.Warn().Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Rate limited") + return + } + conn, _, err := hj.Hijack() + if err != nil { + panic(err) + } + + log.Warn().Str("host", r.Host).Str("uri", r.RequestURI).Str("method", r.Method).Str("remote", addr).Msg("Rate limited") + conn.Close() + return + } + l.c <- addr + next.ServeHTTP(w, r) + }) +} diff --git a/router.go b/router.go new file mode 100644 index 0000000..bdb13ed --- /dev/null +++ b/router.go @@ -0,0 +1,142 @@ +package domainrouter + +import ( + "errors" + "fmt" + "io" + "net/http" + "net/http/httputil" + + "github.com/rs/zerolog/log" +) + +// ConstMap for disallowing change of elements during runtime, for threadsafty +type constMap[K comparable, V any] struct { + dirty map[K]V +} + +func NewConstMap[K comparable, V any](m map[K]V) *constMap[K, V] { + return &constMap[K, V]{ + dirty: m, + } +} + +func (m *constMap[K, V]) Get(key K) (value V, ok bool) { + value, ok = m.dirty[key] + return value, ok +} + +type Router struct { + domains *constMap[string, int] + client *http.Client +} + +func New(domains map[string]int, client *http.Client) Router { + return Router{ + domains: NewConstMap(domains), + client: client, + } +} + +func (router *Router) Route(w http.ResponseWriter, r *http.Request) { + port, ok := router.domains.Get(r.Host) + if !ok { + w.WriteHeader(http.StatusOK) + return + } + + if !dumpRequest(w, r) { + return + } + + subUrlPath := r.URL.RequestURI() + req, err := http.NewRequest(r.Method, fmt.Sprintf("http://localhost:%d%s", port, subUrlPath), r.Body) + if err != nil { + log.Error().Err(err).Str("path", subUrlPath).Int("port", port).Msg("Could not create request") + w.WriteHeader(http.StatusInternalServerError) + return + } + + for name, values := range r.Header { + for _, value := range values { + req.Header.Set(name, value) + } + } + + for _, cookie := range r.Cookies() { + req.AddCookie(cookie) + } + + if !dumpRequest(w, req) { + return + } + + res, err := router.client.Do(req) + if err != nil { + log.Error().Err(err).Str("path", subUrlPath).Int("port", port).Msg("Could not complete request") + w.WriteHeader(http.StatusInternalServerError) + return + } + + cookies := res.Cookies() + for _, cookie := range cookies { + // fmt.Printf("Setting cookie, Name: %s, Value: %s\n", cookie.Name, cookie.Value) + http.SetCookie(w, cookie) + } + + if !dumpResponse(w, res) { + return + } + + if loc, err := res.Location(); !errors.Is(err, http.ErrNoLocation) { + http.Redirect(w, r, loc.RequestURI(), http.StatusFound) + } else { + for name, values := range res.Header { + for _, value := range values { + w.Header().Set(name, value) + } + } + w.WriteHeader(res.StatusCode) + + body, err := io.ReadAll(res.Body) + defer res.Body.Close() + if err != nil { + log.Error().Err(err).Msg("Could not read body") + w.WriteHeader(http.StatusInternalServerError) + return + } + + _, err = w.Write(body) + if err != nil { + log.Error().Err(err).Msg("Could not write body") + w.WriteHeader(http.StatusInternalServerError) + return + } + } +} + +func dumpRequest(w http.ResponseWriter, r *http.Request) bool { + if e := log.Debug(); e.Enabled() && r.Method == "POST" { + rDump, err := httputil.DumpRequest(r, true) + if err != nil { + log.Error().Err(err).Msg("Could not dump request") + w.WriteHeader(http.StatusInternalServerError) + return false + } + log.Debug().Str("dump", string(rDump)).Send() + } + return true +} + +func dumpResponse(w http.ResponseWriter, r *http.Response) bool { + if e := log.Trace(); e.Enabled() { + dump, err := httputil.DumpResponse(r, true) + if err != nil { + log.Error().Err(err).Msg("Could not dump response") + w.WriteHeader(http.StatusInternalServerError) + return false + } + log.Trace().Str("dump", string(dump)).Send() + } + return true +}