Files
domain-router/router.go

406 lines
9.0 KiB
Go

package domainrouter
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httputil"
"net/url"
"slices"
"strings"
"sync/atomic"
"github.com/rs/zerolog/log"
)
// Router is an http.Handler that forwards incoming requests to the remotes
// registered for the request's Host header.
type Router struct {
	// config is the configuration the router was built from.
	config *Config
	// domains maps a domain name to the Host entry that serves it.
	domains *ThreadMap[string, Host]
	// client performs the outgoing requests (proxying and health probes).
	client *http.Client
}
// Host describes the backend side of one routed domain.
//
// The field order is significant: New fills this struct with a positional
// composite literal.
type Host struct {
	// Port is the TCP port used on every remote.
	Port int
	// Remotes lists the backend host names or addresses.
	Remotes []string
	// Secure selects https (true) or http (false) for outgoing requests.
	Secure bool
	// Current is the index into Remotes of the remote used for the next request.
	Current *atomic.Uint32
}
// New builds a Router from config. Every domain of every configured host gets
// its own Host entry (with its own round-robin counter starting at zero);
// client is used for all outgoing requests.
func New(config *Config, client *http.Client) Router {
	byDomain := make(map[string]Host)
	for _, entry := range config.Hosts {
		for _, name := range entry.Domains {
			byDomain[name] = Host{
				Port:    entry.Port,
				Remotes: entry.Remotes,
				Secure:  entry.Secure,
				Current: new(atomic.Uint32),
			}
		}
	}

	var router Router
	router.config = config
	router.domains = NewThreadMap(byDomain)
	router.client = client
	return router
}
// roundRobin advances host.Current to the next index in host.Remotes, wrapping
// back to 0 after the last remote. It does nothing when the host has fewer
// than two remotes.
//
// The update is a compare-and-swap loop so that concurrent callers can never
// push the counter past len(host.Remotes); a separate Load followed by Add
// would allow exactly that.
func (router *Router) roundRobin(host *Host) {
	n := uint32(len(host.Remotes))
	if n <= 1 {
		return
	}
	for {
		cur := host.Current.Load()
		next := cur + 1
		if next >= n {
			next = 0
		}
		if host.Current.CompareAndSwap(cur, next) {
			return
		}
	}
}
// Healthz reports the health of every public host as a JSON array of
// {"Domain": ..., "Healthy": ...} objects, one per domain.
//
// It answers 404 unless General.AnnouncePublic is set. A host counts as healthy
// only when every one of its remotes answers GET /healthz with status 200;
// all remotes are probed even after the first failure.
func (router *Router) Healthz(w http.ResponseWriter, r *http.Request) {
	if !router.config.General.AnnouncePublic {
		http.NotFound(w, r)
		return
	}

	type domainStatus struct {
		Domain  string
		Healthy bool
	}
	// Non-nil so that an empty result encodes as [] rather than null.
	result := make([]domainStatus, 0)

	for _, host := range router.config.Hosts {
		if !host.Public {
			continue
		}
		scheme := "http"
		if host.Secure {
			scheme = "https"
		}
		healthy := true
		for _, remote := range host.Remotes {
			target := fmt.Sprintf("%s://%s:%d/healthz", scheme, remote, host.Port)
			res, err := router.client.Get(target)
			if err != nil {
				log.Warn().Err(err).Str("remote", remote).Int("port", host.Port).Msg("Unhealthy")
				healthy = false
				continue
			}
			// Only the status code is of interest; release the body right away.
			res.Body.Close()
			if res.StatusCode != http.StatusOK {
				healthy = false
			}
		}
		for _, domain := range host.Domains {
			result = append(result, domainStatus{Domain: domain, Healthy: healthy})
		}
	}

	data, err := json.Marshal(&result)
	if err != nil {
		log.Error().Err(err).Msg("Could not json encode Healthz")
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	w.Header().Add("Content-Type", "application/json")
	// The status must be sent before the body; calling WriteHeader after Write
	// has no effect.
	w.WriteHeader(http.StatusOK)
	if _, err := w.Write(data); err != nil {
		log.Warn().Err(err).Msg("Could not write Healthz response")
	}
}
// rewriteRequestURL retargets r at the given remote of host.
//
// It points r.URL's scheme and host at remote:host.Port (https when
// host.Secure, http otherwise), clears r.RequestURI so the request can be sent
// by an http.Client, and sets the X-Forwarded-For and Cache-Control headers.
// The path and query of r.URL are left as they are.
//
// An error is returned when the resulting target URL does not parse; r is not
// modified in that case.
func rewriteRequestURL(r *http.Request, host *Host, remote string) error {
	scheme := "http"
	if host.Secure {
		scheme = "https"
	}
	uri := fmt.Sprintf("%s://%s:%d%s", scheme, remote, host.Port, r.URL.RequestURI())
	remoteUrl, err := url.Parse(uri)
	if err != nil {
		return err
	}

	r.RequestURI = ""
	r.URL.Scheme = remoteUrl.Scheme
	r.URL.Host = remoteUrl.Host

	// X-Forwarded-For carries addresses, not address:port pairs. Hostname
	// strips a trailing port and the brackets around IPv6 literals; fall back
	// to the raw value if nothing is left.
	clientIP := (&url.URL{Host: r.RemoteAddr}).Hostname()
	if clientIP == "" {
		clientIP = r.RemoteAddr
	}
	r.Header.Set("X-Forwarded-For", clientIP)
	r.Header.Set("Cache-Control", "no-store, no-cache, max-age=0, must-revalidate, proxy-revalidate")
	return nil
}
// ServeHTTP proxies r to one of the remotes registered for r.Host and relays
// the remote's response to w.
//
// The outgoing request is a clone of r with headers named by Connection and
// client-supplied forwarding headers removed. A 101 response from the remote
// is handed to handleUpgradeResponse; any other response is copied to w.
func (router *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// If trace enabled dump incoming request, could break request so exit early if that happens
	if !dumpRequest(w, r) {
		return
	}

	host, ok := router.domains.Get(r.Host)
	if !ok {
		log.Warn().Str("host", r.Host).Msg("Could not find Host")
		// NOTE(review): an unknown host is answered with an empty 200; confirm
		// this is intended rather than a 404/502.
		w.WriteHeader(http.StatusOK)
		return
	}
	if len(host.Remotes) == 0 {
		log.Error().Str("host", r.Host).Msg("Host has no remotes configured")
		w.WriteHeader(http.StatusBadGateway)
		return
	}
	// Reduce the counter modulo the remote count so that a stale or
	// out-of-range value can never index past the slice.
	remote := host.Remotes[host.Current.Load()%uint32(len(host.Remotes))]
	// host is a copy, but Current is a shared pointer, so this advances the
	// counter seen by later requests.
	router.roundRobin(&host)

	// Clone the request; drop the body when the client declared it empty.
	ctx := r.Context()
	outreq := r.Clone(ctx)
	if r.ContentLength == 0 {
		outreq.Body = nil
	}
	if outreq.Body != nil {
		defer outreq.Body.Close()
	}

	// Determine the upgrade type before the hop-by-hop headers are stripped.
	reqUpType := upgradeType(outreq.Header)
	if !isPrintableAscii(reqUpType) {
		log.Error().Str("request_upgrade_type", reqUpType).Msg("Client tried to switch to invalid protocol")
		w.WriteHeader(http.StatusBadRequest)
		return
	}
	removeHopByHopHeaders(outreq.Header)
	if slices.Contains(r.Header.Values("Te"), "trailers") {
		outreq.Header.Set("Te", "trailers")
	}
	// Re-add the upgrade headers that the hop-by-hop pass may have removed.
	if reqUpType != "" {
		outreq.Header.Set("Connection", "Upgrade")
		outreq.Header.Set("Upgrade", reqUpType)
	}
	stripClientProvidedXForwardHeaders(outreq.Header)
	if _, ok := outreq.Header["User-Agent"]; !ok {
		// An explicit empty value keeps the client from adding its default User-Agent.
		outreq.Header.Set("User-Agent", "")
	}

	if err := rewriteRequestURL(outreq, &host, remote); err != nil {
		log.Error().Err(err).Bool("secure", host.Secure).Str("remote", remote).Int("port", host.Port).Msg("Could not create request")
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	// Dump created request
	if !dumpRequest(w, outreq) {
		return
	}

	res, err := router.client.Do(outreq)
	if err != nil {
		log.Error().Err(err).Str("remote", remote).Int("port", host.Port).Msg("Could not complete request")
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	// If trace enabled dump response
	if !dumpResponse(w, res) {
		res.Body.Close()
		return
	}
	if res.StatusCode == http.StatusSwitchingProtocols {
		// The upgrade handler owns the response body (the backend connection).
		router.handleUpgradeResponse(w, res, outreq)
		return
	}
	defer res.Body.Close()

	removeHopByHopHeaders(res.Header)
	copyHeader(w.Header(), res.Header)
	w.WriteHeader(res.StatusCode)
	// The status line is already sent, so a copy failure cannot be reported to
	// the client any more; the result is dropped deliberately.
	_ = router.copyResponse(w, res.Body)
}
// copyResponse streams src to dst through a 32 KiB buffer until src reports
// io.EOF or an error.
//
// Bytes returned together with a read error are still written before the
// error is acted upon. Read errors other than io.EOF and context cancellation
// are logged. Any read error ends the loop, so a reader that keeps returning
// the same error cannot make this spin.
//
// It returns nil on io.EOF, the write error if writing fails,
// io.ErrShortWrite if dst accepts fewer bytes than offered, and otherwise the
// read error.
func (router *Router) copyResponse(dst http.ResponseWriter, src io.ReadCloser) error {
	buf := make([]byte, 32*1024)
	for {
		nr, rerr := src.Read(buf)
		if rerr != nil && rerr != io.EOF && !errors.Is(rerr, context.Canceled) {
			log.Error().Err(rerr).Msg("Could not copy body")
		}
		if nr > 0 {
			nw, werr := dst.Write(buf[:nr])
			if werr != nil {
				return werr
			}
			if nw != nr {
				return io.ErrShortWrite
			}
		}
		if rerr != nil {
			if rerr == io.EOF {
				return nil
			}
			return rerr
		}
	}
}
// copyHeader appends every value of every header in src to dst. Values already
// present in dst are kept.
func copyHeader(dst, src http.Header) {
	for name, values := range src {
		for i := range values {
			dst.Add(name, values[i])
		}
	}
}
// handleUpgradeResponse completes a protocol switch after the backend answered
// req with the 101 response res.
//
// It checks that request and response agree on the upgrade type, hijacks the
// client connection behind w, writes the response head to the client and then
// copies bytes in both directions until either side finishes.
//
// Failures that occur before the client connection has been hijacked are
// answered with 502 Bad Gateway and the backend body is closed.
func (router *Router) handleUpgradeResponse(w http.ResponseWriter, res *http.Response, req *http.Request) {
	reqUpType := upgradeType(req.Header)
	resUpType := upgradeType(res.Header)
	if !strings.EqualFold(reqUpType, resUpType) {
		log.Error().Str("response_upgrade_type", resUpType).Str("request_upgrade_type", reqUpType).Msg("Response and Request Upgrade type do not match")
		res.Body.Close()
		w.WriteHeader(http.StatusBadGateway)
		return
	}

	backConn, ok := res.Body.(io.ReadWriteCloser)
	if !ok {
		log.Error().Msg("Could not switch protocols with non writable body")
		res.Body.Close()
		w.WriteHeader(http.StatusBadGateway)
		return
	}

	rc := http.NewResponseController(w)
	conn, brw, hijackErr := rc.Hijack()
	if errors.Is(hijackErr, http.ErrNotSupported) {
		log.Error().Type("response_writer_type", w).Msg("Could not switch protocols using non-Hijacker ResponseWriter")
		backConn.Close()
		w.WriteHeader(http.StatusBadGateway)
		return
	}

	// Close the backend connection when the request is cancelled or when this
	// function returns, whichever happens first.
	backConnCloseCh := make(chan bool)
	go func() {
		select {
		case <-req.Context().Done():
		case <-backConnCloseCh:
		}
		backConn.Close()
	}()
	defer close(backConnCloseCh)

	if hijackErr != nil {
		log.Error().Err(hijackErr).Msg("Hijack failed on protocol switch")
		w.WriteHeader(http.StatusBadGateway)
		return
	}
	defer conn.Close()

	copyHeader(w.Header(), res.Header)
	res.Header = w.Header()
	// With a nil body res.Write emits only the status line and headers; the
	// body itself is still reachable through backConn.
	res.Body = nil
	if err := res.Write(brw); err != nil {
		log.Error().Err(err).Msg("Could not write")
		return
	}
	if err := brw.Flush(); err != nil {
		log.Error().Err(err).Msg("Could not flush")
		return
	}

	errc := make(chan error, 1)
	spc := switchProtocolCopier{user: conn, backend: backConn}
	go spc.copyToBackend(errc)
	go spc.copyFromBackend(errc)

	// A nil first result means one direction was half-closed cleanly; wait for
	// the other direction as well.
	err := <-errc
	if err == nil {
		err = <-errc
	}
	if err != nil && err != errCopyDone {
		log.Warn().Err(err).Msg("Could not copy upgraded connection")
	}
}
// switchProtocolCopier holds the two ends of an upgraded connection and copies
// bytes between them.
type switchProtocolCopier struct {
	// user is the connection to the client.
	user io.ReadWriter
	// backend is the connection to the remote.
	backend io.ReadWriter
}
// errCopyDone is sent on the error channel by the switchProtocolCopier copy
// methods when a copy ends without error and the destination has no
// CloseWrite method.
var errCopyDone = errors.New("hijacked connection copy complete")
// copyFromBackend copies from the backend to the user connection until EOF or
// an error, then sends exactly one value on errc: the copy error if there was
// one, otherwise the result of CloseWrite on the user connection when it has
// such a method, otherwise errCopyDone.
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
	outcome := errCopyDone
	if _, copyErr := io.Copy(c.user, c.backend); copyErr != nil {
		outcome = copyErr
	} else if halfCloser, ok := c.user.(interface{ CloseWrite() error }); ok {
		// The backend reached EOF, so propagate the write-side close to the user.
		outcome = halfCloser.CloseWrite()
	}
	errc <- outcome
}
// copyToBackend copies from the user to the backend connection until EOF or an
// error, then sends exactly one value on errc: the copy error if there was
// one, otherwise the result of CloseWrite on the backend connection when it
// has such a method, otherwise errCopyDone.
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
	outcome := errCopyDone
	if _, copyErr := io.Copy(c.backend, c.user); copyErr != nil {
		outcome = copyErr
	} else if halfCloser, ok := c.backend.(interface{ CloseWrite() error }); ok {
		// The user reached EOF, so propagate the write-side close to the backend.
		outcome = halfCloser.CloseWrite()
	}
	errc <- outcome
}
// stripClientProvidedXForwardHeaders deletes the Forwarded, X-Forwarded-For,
// X-Forwarded-Host and X-Forwarded-Proto headers from header.
func stripClientProvidedXForwardHeaders(header http.Header) {
	forwarding := [...]string{
		"Forwarded",
		"X-Forwarded-For",
		"X-Forwarded-Host",
		"X-Forwarded-Proto",
	}
	for _, name := range forwarding {
		header.Del(name)
	}
}
// removeHopByHopHeaders deletes from header every field named by a Connection
// header, and then the Connection header itself (RFC 7230, section 6.1).
//
// Each Connection value is treated as a comma-separated list; names are
// trimmed of surrounding whitespace and empty entries are ignored.
func removeHopByHopHeaders(header http.Header) {
	for _, value := range header.Values("Connection") {
		for _, name := range strings.Split(value, ",") {
			if name = strings.TrimSpace(name); name != "" {
				header.Del(name)
			}
		}
	}
	header.Del("Connection")
}
// isPrintableAscii reports whether reqUpType consists solely of printable
// ASCII characters (code points 32 through 126). The empty string is
// considered printable.
func isPrintableAscii(reqUpType string) bool {
	for _, c := range reqUpType {
		// Reject control characters, DEL and everything outside ASCII.
		if c < 32 || c > 126 {
			return false
		}
	}
	return true
}
// upgradeType returns the value of the Upgrade header when a Connection header
// contains the token "upgrade", and "" otherwise.
//
// Connection values are treated as comma-separated lists and the token is
// matched case-insensitively, so "keep-alive, Upgrade" is recognised.
func upgradeType(header http.Header) string {
	for _, val := range header.Values("Connection") {
		for _, token := range strings.Split(val, ",") {
			if strings.EqualFold(strings.TrimSpace(token), "upgrade") {
				return header.Get("Upgrade")
			}
		}
	}
	return ""
}
// dumpRequest logs a full dump of r, body included, at trace level. Only POST
// requests are dumped, and only while trace logging is enabled.
//
// It returns false after answering 500 on w when the dump fails, in which case
// the caller should stop handling the request; otherwise it returns true.
func dumpRequest(w http.ResponseWriter, r *http.Request) bool {
	probe := log.Trace()
	if !probe.Enabled() || r.Method != "POST" {
		return true
	}

	raw, dumpErr := httputil.DumpRequest(r, true)
	if dumpErr != nil {
		log.Error().Err(dumpErr).Msg("Could not dump request")
		w.WriteHeader(http.StatusInternalServerError)
		return false
	}

	log.Trace().Str("dump", string(raw)).Msg("Dumping Request")
	return true
}
// dumpResponse logs a dump of r at trace level while trace logging is enabled.
//
// The body is included for every status except 101 Switching Protocols. For a
// 101 the body is the live backend connection: dumping it would read it to the
// end and swap it for an in-memory copy, so only the head is dumped.
//
// It returns false after answering 500 on w when the dump fails, in which case
// the caller should stop handling the request; otherwise it returns true.
func dumpResponse(w http.ResponseWriter, r *http.Response) bool {
	if e := log.Trace(); e.Enabled() {
		withBody := r.StatusCode != http.StatusSwitchingProtocols
		dump, err := httputil.DumpResponse(r, withBody)
		if err != nil {
			log.Error().Err(err).Msg("Could not dump response")
			w.WriteHeader(http.StatusInternalServerError)
			return false
		}
		log.Trace().Str("dump", string(dump)).Msg("Dumping Response")
	}
	return true
}