Update domain router to work differntly but similarly to httpuitl reverseproxy
This commit is contained in:
@@ -40,7 +40,7 @@ func main() {
|
|||||||
|
|
||||||
router := domainrouter.New(config, client)
|
router := domainrouter.New(config, client)
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
mux.HandleFunc("/", router.Route)
|
mux.HandleFunc("/", router.ServeHTTP)
|
||||||
|
|
||||||
if config.General.AnnouncePublic {
|
if config.General.AnnouncePublic {
|
||||||
h, err := url.JoinPath("/", config.General.HealthEndpoint)
|
h, err := url.JoinPath("/", config.General.HealthEndpoint)
|
||||||
|
|||||||
13
config.yaml
13
config.yaml
@@ -5,7 +5,7 @@ server:
|
|||||||
certFile: server.crt
|
certFile: server.crt
|
||||||
keyFile: server.key
|
keyFile: server.key
|
||||||
acme:
|
acme:
|
||||||
enabled: true
|
enabled: false
|
||||||
email: me@pablu.de
|
email: me@pablu.de
|
||||||
keyFile: userKey.key
|
keyFile: userKey.key
|
||||||
caDirUrl: https://192.168.2.154:14000/dir
|
caDirUrl: https://192.168.2.154:14000/dir
|
||||||
@@ -29,7 +29,7 @@ logging:
|
|||||||
|
|
||||||
|
|
||||||
rateLimit:
|
rateLimit:
|
||||||
enabled: true
|
enabled: false
|
||||||
# How many requests per ip adress are allowed
|
# How many requests per ip adress are allowed
|
||||||
bucketSize: 50
|
bucketSize: 50
|
||||||
# How many requests per ip address are refilled
|
# How many requests per ip address are refilled
|
||||||
@@ -44,7 +44,6 @@ hosts:
|
|||||||
# Remote address to request
|
# Remote address to request
|
||||||
- remotes:
|
- remotes:
|
||||||
- localhost
|
- localhost
|
||||||
- 192.168.2.154
|
|
||||||
# Port on which to request
|
# Port on which to request
|
||||||
port: 8181
|
port: 8181
|
||||||
# Health check if announce is true
|
# Health check if announce is true
|
||||||
@@ -61,6 +60,14 @@ hosts:
|
|||||||
domains:
|
domains:
|
||||||
- private.localhost
|
- private.localhost
|
||||||
|
|
||||||
|
- remotes:
|
||||||
|
- localhost
|
||||||
|
port: 5173
|
||||||
|
public: false
|
||||||
|
domains:
|
||||||
|
- hitstar.localhost
|
||||||
|
- hipstar.localhost
|
||||||
|
|
||||||
# - remotes:
|
# - remotes:
|
||||||
# - www.google.com
|
# - www.google.com
|
||||||
# - localhost
|
# - localhost
|
||||||
|
|||||||
306
router.go
306
router.go
@@ -1,12 +1,16 @@
|
|||||||
package domainrouter
|
package domainrouter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
|
"net/url"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
@@ -103,87 +107,30 @@ func (router *Router) Healthz(w http.ResponseWriter, r *http.Request) {
|
|||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createRequest(r *http.Request, host *Host, remote string) (*http.Request, error) {
|
func rewriteRequestURL(r *http.Request, host *Host, remote string) error {
|
||||||
subUrlPath := r.URL.RequestURI()
|
subUrlPath := r.URL.RequestURI()
|
||||||
var url string
|
var uri string
|
||||||
if host.Secure {
|
if host.Secure {
|
||||||
url = fmt.Sprintf("https://%s:%d%s", remote, host.Port, subUrlPath)
|
uri = fmt.Sprintf("https://%s:%d%s", remote, host.Port, subUrlPath)
|
||||||
} else {
|
} else {
|
||||||
url = fmt.Sprintf("http://%s:%d%s", remote, host.Port, subUrlPath)
|
uri = fmt.Sprintf("http://%s:%d%s", remote, host.Port, subUrlPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequest(r.Method, url, r.Body)
|
remoteUrl, err := url.Parse(uri)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
copyRequestHeader(r, req)
|
|
||||||
req.Header.Set("X-Forwarded-For", r.RemoteAddr)
|
|
||||||
req.Header.Set("Cache-Control", "no-store, no-cache, max-age=0, must-revalidate, proxy-revalidate")
|
|
||||||
|
|
||||||
for _, cookie := range r.Cookies() {
|
|
||||||
req.AddCookie(cookie)
|
|
||||||
}
|
|
||||||
|
|
||||||
return req, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var etagHeaders = map[string]struct{}{
|
|
||||||
"ETag": {},
|
|
||||||
"If-Modified-Since": {},
|
|
||||||
"If-Match": {},
|
|
||||||
"If-None-Match": {},
|
|
||||||
"If-Range": {},
|
|
||||||
"If-Unmodified-Since": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
func copyRequestHeader(origin *http.Request, destination *http.Request) {
|
|
||||||
for name, values := range origin.Header {
|
|
||||||
|
|
||||||
// Skip etag Headers
|
|
||||||
if _, ok := etagHeaders[name]; ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, value := range values {
|
|
||||||
destination.Header.Set(name, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func applyResponseHeader(w http.ResponseWriter, res *http.Response) {
|
|
||||||
for name, values := range res.Header {
|
|
||||||
for _, value := range values {
|
|
||||||
w.Header().Set(name, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
w.WriteHeader(res.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
func applyCookies(w http.ResponseWriter, res *http.Response) {
|
|
||||||
cookies := res.Cookies()
|
|
||||||
for _, cookie := range cookies {
|
|
||||||
http.SetCookie(w, cookie)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func applyBody(w http.ResponseWriter, res *http.Response) error {
|
|
||||||
body, err := io.ReadAll(res.Body)
|
|
||||||
defer res.Body.Close()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = w.Write(body)
|
r.RequestURI = ""
|
||||||
if err != nil {
|
r.URL.Scheme = remoteUrl.Scheme
|
||||||
return err
|
r.URL.Host = remoteUrl.Host
|
||||||
}
|
r.Header.Set("X-Forwarded-For", r.RemoteAddr)
|
||||||
|
r.Header.Set("Cache-Control", "no-store, no-cache, max-age=0, must-revalidate, proxy-revalidate")
|
||||||
|
|
||||||
w.WriteHeader(res.StatusCode)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (router *Router) Route(w http.ResponseWriter, r *http.Request) {
|
func (router *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
// If trace enabled dump incoming request, could break request so exit early if that happens
|
// If trace enabled dump incoming request, could break request so exit early if that happens
|
||||||
if !dumpRequest(w, r) {
|
if !dumpRequest(w, r) {
|
||||||
return
|
return
|
||||||
@@ -199,7 +146,41 @@ func (router *Router) Route(w http.ResponseWriter, r *http.Request) {
|
|||||||
remote := host.Remotes[host.Current.Load()]
|
remote := host.Remotes[host.Current.Load()]
|
||||||
go router.roundRobin(&host)
|
go router.roundRobin(&host)
|
||||||
|
|
||||||
req, err := createRequest(r, &host, remote)
|
// Copy request
|
||||||
|
// Copy body with Buffer Pool
|
||||||
|
ctx := r.Context()
|
||||||
|
|
||||||
|
outreq := r.Clone(ctx)
|
||||||
|
if r.ContentLength == 0 {
|
||||||
|
outreq.Body = nil
|
||||||
|
}
|
||||||
|
if outreq.Body != nil {
|
||||||
|
defer outreq.Body.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
reqUpType := upgradeType(outreq.Header)
|
||||||
|
if !isPrintableAscii(reqUpType) {
|
||||||
|
log.Error().Str("request_upgrade_type", reqUpType).Msg("Client tried to switch to invalid protocol")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
removeHopByHopHeaders(outreq.Header)
|
||||||
|
|
||||||
|
if slices.Contains(r.Header.Values("Te"), "trailers") {
|
||||||
|
outreq.Header.Set("Te", "trailers")
|
||||||
|
}
|
||||||
|
|
||||||
|
if reqUpType != "" {
|
||||||
|
outreq.Header.Set("Connection", "Upgrade")
|
||||||
|
outreq.Header.Set("Upgrade", reqUpType)
|
||||||
|
}
|
||||||
|
|
||||||
|
stripClientProvidedXForwardHeaders(outreq.Header)
|
||||||
|
|
||||||
|
if _, ok := outreq.Header["User-Agent"]; !ok {
|
||||||
|
outreq.Header.Set("User-Agent", "")
|
||||||
|
}
|
||||||
|
|
||||||
|
err := rewriteRequestURL(outreq, &host, remote)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Bool("secure", host.Secure).Str("remote", remote).Int("port", host.Port).Msg("Could not create request")
|
log.Error().Err(err).Bool("secure", host.Secure).Str("remote", remote).Int("port", host.Port).Msg("Could not create request")
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
@@ -207,11 +188,11 @@ func (router *Router) Route(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Dump created request
|
// Dump created request
|
||||||
if !dumpRequest(w, req) {
|
if !dumpRequest(w, outreq) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := router.client.Do(req)
|
res, err := router.client.Do(outreq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Str("remote", remote).Int("port", host.Port).Msg("Could not complete request")
|
log.Error().Err(err).Str("remote", remote).Int("port", host.Port).Msg("Could not complete request")
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
@@ -223,35 +204,180 @@ func (router *Router) Route(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
applyCookies(w, res)
|
if res.StatusCode == http.StatusSwitchingProtocols {
|
||||||
applyResponseHeader(w, res)
|
router.handleUpgradeResponse(w, res, outreq)
|
||||||
|
|
||||||
// Exit early because its a redirect
|
|
||||||
// Maybe this should be before applying cookies or after applying headers
|
|
||||||
if !handleLocation(w, r, res) {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = applyBody(w, res)
|
removeHopByHopHeaders(res.Header)
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("Could not apply body")
|
copyHeader(w.Header(), res.Header)
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return
|
w.WriteHeader(res.StatusCode)
|
||||||
|
err = router.copyResponse(w, res.Body)
|
||||||
|
res.Body.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (router *Router) copyResponse(dst http.ResponseWriter, src io.ReadCloser) error {
|
||||||
|
buf := make([]byte, 32*1024)
|
||||||
|
var written int64
|
||||||
|
for {
|
||||||
|
nr, rerr := src.Read(buf)
|
||||||
|
if rerr != nil && rerr != io.EOF && rerr != context.Canceled {
|
||||||
|
log.Error().Err(rerr).Msg("Could not copy body")
|
||||||
|
return rerr
|
||||||
|
}
|
||||||
|
if nr > 0 {
|
||||||
|
nw, werr := dst.Write(buf[:nr])
|
||||||
|
if nw > 0 {
|
||||||
|
written += int64(nw)
|
||||||
|
}
|
||||||
|
if werr != nil {
|
||||||
|
return werr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if rerr != nil && rerr == io.EOF {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleLocation(w http.ResponseWriter, r *http.Request, res *http.Response) bool {
|
func copyHeader(dst, src http.Header) {
|
||||||
if loc, err := res.Location(); err == nil {
|
for k, vv := range src {
|
||||||
http.Redirect(w, r, loc.String(), http.StatusFound)
|
for _, v := range vv {
|
||||||
return false
|
dst.Add(k, v)
|
||||||
} else if !errors.Is(err, http.ErrNoLocation) {
|
}
|
||||||
log.Error().Err(err).Msg("Could not extract location")
|
}
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
}
|
||||||
return false
|
|
||||||
|
func (router *Router) handleUpgradeResponse(w http.ResponseWriter, res *http.Response, req *http.Request) {
|
||||||
|
reqUpType := upgradeType(req.Header)
|
||||||
|
resUpType := upgradeType(res.Header)
|
||||||
|
if !strings.EqualFold(reqUpType, resUpType) {
|
||||||
|
log.Error().Str("response_upgrade_type", resUpType).Str("request_upgrade_type", reqUpType).Msg("Response and Request Upgrade type do not match")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
backConn, ok := res.Body.(io.ReadWriteCloser)
|
||||||
|
if !ok {
|
||||||
|
log.Error().Msg("Could not switch protocols with non writable body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rc := http.NewResponseController(w)
|
||||||
|
conn, brw, hijackErr := rc.Hijack()
|
||||||
|
if errors.Is(hijackErr, http.ErrNotSupported) {
|
||||||
|
log.Error().Type("response_writer_type", w).Msg("Could not switch protocols using non-Hijacker ResponseWriter")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
backConnCloseCh := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-req.Context().Done():
|
||||||
|
case <-backConnCloseCh:
|
||||||
|
}
|
||||||
|
backConn.Close()
|
||||||
|
}()
|
||||||
|
defer close(backConnCloseCh)
|
||||||
|
|
||||||
|
if hijackErr != nil {
|
||||||
|
log.Error().Err(hijackErr).Msg("Hijack failed on protocol switch")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
copyHeader(w.Header(), res.Header)
|
||||||
|
|
||||||
|
res.Header = w.Header()
|
||||||
|
res.Body = nil
|
||||||
|
if err := res.Write(brw); err != nil {
|
||||||
|
log.Error().Err(err).Msg("Could not write")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := brw.Flush(); err != nil {
|
||||||
|
log.Error().Err(err).Msg("Could not flush")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
errc := make(chan error, 1)
|
||||||
|
spc := switchProtocolCopier{user: conn, backend: backConn}
|
||||||
|
go spc.copyToBackend(errc)
|
||||||
|
go spc.copyFromBackend(errc)
|
||||||
|
|
||||||
|
err := <-errc
|
||||||
|
if err == nil {
|
||||||
|
err = <-errc
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type switchProtocolCopier struct {
|
||||||
|
user, backend io.ReadWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
var errCopyDone = errors.New("hijacked connection copy complete")
|
||||||
|
|
||||||
|
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
|
||||||
|
if _, err := io.Copy(c.user, c.backend); err != nil {
|
||||||
|
errc <- err
|
||||||
|
return
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// backend conn has reached EOF so propogate close write to user conn
|
||||||
|
if wc, ok := c.user.(interface{ CloseWrite() error }); ok {
|
||||||
|
errc <- wc.CloseWrite()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
errc <- errCopyDone
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
|
||||||
|
if _, err := io.Copy(c.backend, c.user); err != nil {
|
||||||
|
errc <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// user conn has reached EOF so propogate close write to backend conn
|
||||||
|
if wc, ok := c.backend.(interface{ CloseWrite() error }); ok {
|
||||||
|
errc <- wc.CloseWrite()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
errc <- errCopyDone
|
||||||
|
}
|
||||||
|
|
||||||
|
func stripClientProvidedXForwardHeaders(header http.Header) {
|
||||||
|
header.Del("Forwarded")
|
||||||
|
header.Del("X-Forwarded-For")
|
||||||
|
header.Del("X-Forwarded-Host")
|
||||||
|
header.Del("X-Forwarded-Proto")
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeHopByHopHeaders(header http.Header) {
|
||||||
|
for _, f := range header.Values("Connection") {
|
||||||
|
if strings.TrimSpace(f) != "" {
|
||||||
|
header.Del(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isPrintableAscii(reqUpType string) bool {
|
||||||
|
for _, c := range reqUpType {
|
||||||
|
if c < 32 && c > 126 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func upgradeType(header http.Header) string {
|
||||||
|
for _, val := range header.Values("Connection") {
|
||||||
|
if strings.ToLower(val) == "upgrade" {
|
||||||
|
return header.Get("Upgrade")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
func dumpRequest(w http.ResponseWriter, r *http.Request) bool {
|
func dumpRequest(w http.ResponseWriter, r *http.Request) bool {
|
||||||
if e := log.Trace(); e.Enabled() && r.Method == "POST" {
|
if e := log.Trace(); e.Enabled() && r.Method == "POST" {
|
||||||
rDump, err := httputil.DumpRequest(r, true)
|
rDump, err := httputil.DumpRequest(r, true)
|
||||||
|
|||||||
Reference in New Issue
Block a user