Add multiple remotes and round robin

This commit is contained in:
Pablu23
2025-05-19 09:30:07 +02:00
parent 403c89b068
commit 55a2d2b708
5 changed files with 73 additions and 36 deletions

View File

@@ -15,7 +15,7 @@ type Config struct {
} `yaml:"server"` } `yaml:"server"`
Hosts []struct { Hosts []struct {
Port int `yaml:"port"` Port int `yaml:"port"`
Remote string `yaml:"remote"` Remotes []string `yaml:"remotes"`
Domains []string `yaml:"domains"` Domains []string `yaml:"domains"`
Public bool `yaml:"public"` Public bool `yaml:"public"`
Secure bool `yaml:"secure"` Secure bool `yaml:"secure"`

View File

@@ -34,7 +34,9 @@ rateLimit:
hosts: hosts:
# Remote address to request # Remote address to request
- remote: localhost - remotes:
- localhost
- 192.168.2.154
# Port on which to request # Port on which to request
port: 8181 port: 8181
# Health check if announce is true # Health check if announce is true
@@ -44,13 +46,16 @@ hosts:
- localhost - localhost
- test.localhost - test.localhost
- remote: localhost - remotes:
- localhost
port: 8282 port: 8282
public: false public: false
domains: domains:
- private.localhost - private.localhost
- remote: www.google.com - remotes:
- www.google.com
- localhost
port: 443 port: 443
public: false public: false
# Uses https under the hood to communicate with the remote host # Uses https under the hood to communicate with the remote host

36
constmap.go Normal file
View File

@@ -0,0 +1,36 @@
package domainrouter
import "sync"
// ThreadMap for disallowing change of elements during runtime, for threadsafty
type ThreadMap[K comparable, V any] struct {
dirty map[K]V
rwMutex sync.RWMutex
}
func NewThreadMap[K comparable, V any](m map[K]V) *ThreadMap[K, V] {
return &ThreadMap[K, V]{
dirty: m,
rwMutex: sync.RWMutex{},
}
}
func (m *ThreadMap[K, V]) Get(key K) (value V, ok bool) {
m.rwMutex.RLock()
defer m.rwMutex.RUnlock()
value, ok = m.dirty[key]
return value, ok
}
func (m *ThreadMap[K, V]) SetValue(key K, change func(old V) V) bool {
m.rwMutex.Lock()
defer m.rwMutex.Unlock()
value, ok := m.dirty[key]
if !ok {
return ok
}
m.dirty[key] = change(value)
return ok
}

View File

@@ -7,38 +7,48 @@ import (
"io" "io"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"sync/atomic"
"github.com/pablu23/domain-router/util"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
type Router struct { type Router struct {
config *Config config *Config
domains *util.ImmutableMap[string, Host] domains *ThreadMap[string, Host]
client *http.Client client *http.Client
} }
type Host struct { type Host struct {
Port int Port int
Remote string Remotes []string
Secure bool Secure bool
Current *atomic.Uint32
} }
func New(config *Config, client *http.Client) Router { func New(config *Config, client *http.Client) Router {
m := make(map[string]Host) m := make(map[string]Host)
for _, host := range config.Hosts { for _, host := range config.Hosts {
for _, domain := range host.Domains { for _, domain := range host.Domains {
m[domain] = Host{host.Port, host.Remote, host.Secure} m[domain] = Host{host.Port, host.Remotes, host.Secure, &atomic.Uint32{}}
} }
} }
return Router{ return Router{
config: config, config: config,
domains: util.NewImmutableMap(m), domains: NewThreadMap(m),
client: client, client: client,
} }
} }
func (router *Router) roundRobin(host *Host) {
l := len(host.Remotes)
if l > 1 && host.Current.Load()+1 < uint32(l) {
host.Current.Add(1)
} else if l > 1 {
host.Current.Store(0)
}
}
func (router *Router) Healthz(w http.ResponseWriter, r *http.Request) { func (router *Router) Healthz(w http.ResponseWriter, r *http.Request) {
if !router.config.General.AnnouncePublic { if !router.config.General.AnnouncePublic {
http.NotFound(w, r) http.NotFound(w, r)
@@ -58,9 +68,9 @@ func (router *Router) Healthz(w http.ResponseWriter, r *http.Request) {
healthy := true healthy := true
var url string var url string
if host.Secure { if host.Secure {
url = fmt.Sprintf("https://%s:%d/healthz", host.Remote, host.Port) url = fmt.Sprintf("https://%s:%d/healthz", host.Remotes, host.Port)
} else { } else {
url = fmt.Sprintf("http://%s:%d/healthz", host.Remote, host.Port) url = fmt.Sprintf("http://%s:%d/healthz", host.Remotes, host.Port)
} }
res, err := router.client.Get(url) res, err := router.client.Get(url)
@@ -89,13 +99,13 @@ func (router *Router) Healthz(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
} }
func createRequest(r *http.Request, host *Host) (*http.Request, error) { func createRequest(r *http.Request, host *Host, remote string) (*http.Request, error) {
subUrlPath := r.URL.RequestURI() subUrlPath := r.URL.RequestURI()
var url string var url string
if host.Secure { if host.Secure {
url = fmt.Sprintf("https://%s:%d%s", host.Remote, host.Port, subUrlPath) url = fmt.Sprintf("https://%s:%d%s", remote, host.Port, subUrlPath)
} else { } else {
url = fmt.Sprintf("http://%s:%d%s", host.Remote, host.Port, subUrlPath) url = fmt.Sprintf("http://%s:%d%s", remote, host.Port, subUrlPath)
} }
req, err := http.NewRequest(r.Method, url, r.Body) req, err := http.NewRequest(r.Method, url, r.Body)
@@ -165,9 +175,12 @@ func (router *Router) Route(w http.ResponseWriter, r *http.Request) {
return return
} }
req, err := createRequest(r, &host) remote := host.Remotes[host.Current.Load()]
go router.roundRobin(&host)
req, err := createRequest(r, &host, remote)
if err != nil { if err != nil {
log.Error().Err(err).Bool("secure", host.Secure).Str("remote", host.Remote).Int("port", host.Port).Msg("Could not create request") log.Error().Err(err).Bool("secure", host.Secure).Str("remote", remote).Int("port", host.Port).Msg("Could not create request")
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
return return
} }
@@ -179,7 +192,7 @@ func (router *Router) Route(w http.ResponseWriter, r *http.Request) {
res, err := router.client.Do(req) res, err := router.client.Do(req)
if err != nil { if err != nil {
log.Error().Err(err).Str("remote", host.Remote).Int("port", host.Port).Msg("Could not complete request") log.Error().Err(err).Str("remote", remote).Int("port", host.Port).Msg("Could not complete request")
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
return return
} }

View File

@@ -1,17 +0,0 @@
package util
// ImmutableMap for disallowing change of elements during runtime, for threadsafty
type ImmutableMap[K comparable, V any] struct {
dirty map[K]V
}
func NewImmutableMap[K comparable, V any](m map[K]V) *ImmutableMap[K, V] {
return &ImmutableMap[K, V]{
dirty: m,
}
}
func (m *ImmutableMap[K, V]) Get(key K) (value V, ok bool) {
value, ok = m.dirty[key]
return value, ok
}