diff --git a/config.go b/config.go index 14374a5..4a3fe25 100644 --- a/config.go +++ b/config.go @@ -15,7 +15,7 @@ type Config struct { } `yaml:"server"` Hosts []struct { Port int `yaml:"port"` - Remote string `yaml:"remote"` + Remotes []string `yaml:"remotes"` Domains []string `yaml:"domains"` Public bool `yaml:"public"` Secure bool `yaml:"secure"` diff --git a/config.yaml b/config.yaml index c40d2f3..c48b604 100644 --- a/config.yaml +++ b/config.yaml @@ -34,7 +34,9 @@ rateLimit: hosts: # Remote address to request - - remote: localhost + - remotes: + - localhost + - 192.168.2.154 # Port on which to request port: 8181 # Health check if announce is true @@ -44,13 +46,16 @@ hosts: - localhost - test.localhost - - remote: localhost + - remotes: + - localhost port: 8282 public: false domains: - private.localhost - - remote: www.google.com + - remotes: + - www.google.com + - localhost port: 443 public: false # Uses https under the hood to communicate with the remote host diff --git a/constmap.go b/constmap.go new file mode 100644 index 0000000..82f9fb5 --- /dev/null +++ b/constmap.go @@ -0,0 +1,36 @@ +package domainrouter + +import "sync" + +// ThreadMap for disallowing change of elements during runtime, for threadsafty +type ThreadMap[K comparable, V any] struct { + dirty map[K]V + rwMutex sync.RWMutex +} + +func NewThreadMap[K comparable, V any](m map[K]V) *ThreadMap[K, V] { + return &ThreadMap[K, V]{ + dirty: m, + rwMutex: sync.RWMutex{}, + } +} + +func (m *ThreadMap[K, V]) Get(key K) (value V, ok bool) { + m.rwMutex.RLock() + defer m.rwMutex.RUnlock() + value, ok = m.dirty[key] + return value, ok +} + +func (m *ThreadMap[K, V]) SetValue(key K, change func(old V) V) bool { + m.rwMutex.Lock() + defer m.rwMutex.Unlock() + + value, ok := m.dirty[key] + if !ok { + return ok + } + + m.dirty[key] = change(value) + return ok +} diff --git a/router.go b/router.go index 161b428..af84dcf 100644 --- a/router.go +++ b/router.go @@ -7,38 +7,48 @@ import ( "io" "net/http" "net/http/httputil" + "sync/atomic" - "github.com/pablu23/domain-router/util" "github.com/rs/zerolog/log" ) type Router struct { config *Config - domains *util.ImmutableMap[string, Host] + domains *ThreadMap[string, Host] client *http.Client } type Host struct { - Port int - Remote string - Secure bool + Port int + Remotes []string + Secure bool + Current *atomic.Uint32 } func New(config *Config, client *http.Client) Router { m := make(map[string]Host) for _, host := range config.Hosts { for _, domain := range host.Domains { - m[domain] = Host{host.Port, host.Remote, host.Secure} + m[domain] = Host{host.Port, host.Remotes, host.Secure, &atomic.Uint32{}} } } return Router{ config: config, - domains: util.NewImmutableMap(m), + domains: NewThreadMap(m), client: client, } } +func (router *Router) roundRobin(host *Host) { + l := len(host.Remotes) + if l > 1 && host.Current.Load()+1 < uint32(l) { + host.Current.Add(1) + } else if l > 1 { + host.Current.Store(0) + } +} + func (router *Router) Healthz(w http.ResponseWriter, r *http.Request) { if !router.config.General.AnnouncePublic { http.NotFound(w, r) @@ -58,9 +68,9 @@ func (router *Router) Healthz(w http.ResponseWriter, r *http.Request) { healthy := true var url string if host.Secure { - url = fmt.Sprintf("https://%s:%d/healthz", host.Remote, host.Port) + url = fmt.Sprintf("https://%s:%d/healthz", host.Remotes, host.Port) } else { - url = fmt.Sprintf("http://%s:%d/healthz", host.Remote, host.Port) + url = fmt.Sprintf("http://%s:%d/healthz", host.Remotes, host.Port) } res, err := router.client.Get(url) @@ -89,13 +99,13 @@ func (router *Router) Healthz(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } -func createRequest(r *http.Request, host *Host) (*http.Request, error) { +func createRequest(r *http.Request, host *Host, remote string) (*http.Request, error) { subUrlPath := r.URL.RequestURI() var url string if host.Secure { - url = fmt.Sprintf("https://%s:%d%s", host.Remote, host.Port, subUrlPath) + url = fmt.Sprintf("https://%s:%d%s", remote, host.Port, subUrlPath) } else { - url = fmt.Sprintf("http://%s:%d%s", host.Remote, host.Port, subUrlPath) + url = fmt.Sprintf("http://%s:%d%s", remote, host.Port, subUrlPath) } req, err := http.NewRequest(r.Method, url, r.Body) @@ -165,9 +175,12 @@ func (router *Router) Route(w http.ResponseWriter, r *http.Request) { return } - req, err := createRequest(r, &host) + remote := host.Remotes[host.Current.Load()] + go router.roundRobin(&host) + + req, err := createRequest(r, &host, remote) if err != nil { - log.Error().Err(err).Bool("secure", host.Secure).Str("remote", host.Remote).Int("port", host.Port).Msg("Could not create request") + log.Error().Err(err).Bool("secure", host.Secure).Str("remote", remote).Int("port", host.Port).Msg("Could not create request") w.WriteHeader(http.StatusInternalServerError) return } @@ -179,7 +192,7 @@ func (router *Router) Route(w http.ResponseWriter, r *http.Request) { res, err := router.client.Do(req) if err != nil { - log.Error().Err(err).Str("remote", host.Remote).Int("port", host.Port).Msg("Could not complete request") + log.Error().Err(err).Str("remote", remote).Int("port", host.Port).Msg("Could not complete request") w.WriteHeader(http.StatusInternalServerError) return } diff --git a/util/constmap.go b/util/constmap.go deleted file mode 100644 index 0d18235..0000000 --- a/util/constmap.go +++ /dev/null @@ -1,17 +0,0 @@ -package util - -// ImmutableMap for disallowing change of elements during runtime, for threadsafty -type ImmutableMap[K comparable, V any] struct { - dirty map[K]V -} - -func NewImmutableMap[K comparable, V any](m map[K]V) *ImmutableMap[K, V] { - return &ImmutableMap[K, V]{ - dirty: m, - } -} - -func (m *ImmutableMap[K, V]) Get(key K) (value V, ok bool) { - value, ok = m.dirty[key] - return value, ok -}