This repository has been archived on 2025-10-15. You can view files and clone it, but cannot push or open issues or pull requests.
Files
Uftp/internal/server/server.go

517 lines
11 KiB
Go

package server
import (
"bufio"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/hex"
"encoding/pem"
"errors"
"fmt"
"io"
"net"
"os"
"os/signal"
"path/filepath"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/Pablu23/Uftp/internal/common"
)
// info is the per-session state the server keeps for one client; instances
// live in Server.sessions keyed by common.SessionID and are guarded by
// Server.mu.
type info struct {
// path is the absolute, cleaned path of the requested file (set by sendPTE).
path string
// lastSync is the Sync value of the last packet whose Ack the server is
// waiting for (set by sendPTE and sendData, checked by handleAck).
lastSync uint32
// lastPckSend is the flag of that packet; handleAck closes the session once
// an End packet has been acknowledged.
lastPckSend common.HeaderFlag
// key is the symmetric key recovered in handleConnection via ExtractKey and
// used to seal/open this session's UDP packets.
key [32]byte
// time is refreshed by sendPacket; cleanup drops sessions whose time is more
// than 30 seconds in the past.
time time.Time
}
// Server serves files from options.Datapath over UDP (port 13374). Symmetric
// session keys are negotiated beforehand over a TCP management port (13375).
type Server struct {
// sessions maps each active session ID to its state; guarded by mu.
sessions map[common.SessionID]*info
// mu guards sessions.
mu sync.Mutex
// rsa is the server's private key; handleConnection passes it to
// ExtractKey to recover each client's symmetric key.
rsa *rsa.PrivateKey
// options holds the configuration New was called with.
options *Options
// parentFilePath is the absolute form of options.Datapath; sendPTE only
// serves files located directly inside it.
parentFilePath string
}
// New creates a Server configured by the given functional options.
//
// When options.LoadPrivkey is set, the RSA private key is read from
// options.PrivKeyPath as a PEM-encoded PKCS#1 block; otherwise a fresh
// 4096-bit key is generated. Afterwards the public and/or private key are
// written to disk if options.SavePubKey / options.SavePrivKey are set. New also
// installs a coloured text formatter on the global logrus logger.
//
// An error is returned if the key file cannot be read, contains no PEM block,
// or does not hold a valid PKCS#1 key; if key generation fails; if
// options.Datapath cannot be made absolute; or if saving a key fails.
func New(opts ...func(*Options)) (*Server, error) {
	options := NewDefaultOptions()
	for _, opt := range opts {
		opt(options)
	}

	var key *rsa.PrivateKey
	if options.LoadPrivkey {
		pemBytes, err := os.ReadFile(options.PrivKeyPath)
		if err != nil {
			return nil, err
		}
		// pem.Decode returns a nil block when no PEM data is found; using it
		// unchecked would panic on block.Bytes.
		block, _ := pem.Decode(pemBytes)
		if block == nil {
			return nil, fmt.Errorf("decoding private key %q: no PEM block found", options.PrivKeyPath)
		}
		key, err = x509.ParsePKCS1PrivateKey(block.Bytes)
		if err != nil {
			return nil, fmt.Errorf("parsing private key %q: %w", options.PrivKeyPath, err)
		}
	} else {
		var err error
		key, err = rsa.GenerateKey(rand.Reader, 4096)
		if err != nil {
			return nil, err
		}
	}

	parentFilePath, err := filepath.Abs(options.Datapath)
	if err != nil {
		return nil, err
	}

	server := &Server{
		sessions:       make(map[common.SessionID]*info),
		rsa:            key,
		options:        options,
		parentFilePath: parentFilePath,
	}

	if options.SavePubKey {
		if err := server.SavePublicKeyPem(); err != nil {
			return nil, err
		}
	}
	if options.SavePrivKey {
		if err := server.SavePrivateKeyPem(); err != nil {
			return nil, err
		}
	}

	log.SetFormatter(&log.TextFormatter{
		ForceColors: true,
	})
	return server, nil
}
// SavePrivateKeyPem writes the server's RSA private key to
// options.PrivKeyPath as a PEM "RSA PRIVATE KEY" (PKCS#1) block.
//
// The file is created or truncated and restricted to mode 0600 so the key is
// not readable by other users, even if the file already existed with wider
// permissions. Errors from opening, chmod-ing, encoding or closing the file
// are returned; a failed Close can mean the key never reached the disk.
func (server *Server) SavePrivateKeyPem() (err error) {
	file, err := os.OpenFile(server.options.PrivKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
	if err != nil {
		return err
	}
	defer func() {
		cerr := file.Close()
		if cerr == nil {
			return
		}
		if err == nil {
			err = cerr
			return
		}
		// An earlier error is already being returned; keep the close failure
		// visible without masking it.
		log.WithError(cerr).Error("Could not close File")
	}()

	// OpenFile's mode only applies to newly created files.
	if err := file.Chmod(0600); err != nil {
		return err
	}

	return pem.Encode(file, &pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(server.rsa),
	})
}
// SavePublicKeyPem writes the public half of the server's RSA key to
// options.PubKeyPath as a PEM "RSA PUBLIC KEY" (PKCS#1) block, creating or
// truncating the file.
//
// Errors from creating, encoding or closing the file are returned; a failed
// Close can mean the data never reached the disk.
func (server *Server) SavePublicKeyPem() (err error) {
	file, err := os.Create(server.options.PubKeyPath)
	if err != nil {
		return err
	}
	defer func() {
		cerr := file.Close()
		if cerr == nil {
			return
		}
		if err == nil {
			err = cerr
			return
		}
		// An earlier error is already being returned; keep the close failure
		// visible without masking it.
		log.WithError(cerr).Error("Could not close File")
	}()

	return pem.Encode(file, &pem.Block{
		Type:  "RSA PUBLIC KEY",
		Bytes: x509.MarshalPKCS1PublicKey(&server.rsa.PublicKey),
	})
}
// sendPacket encrypts pck with the symmetric key of its session and writes
// the result to addr over conn.
//
// Sending refreshes the session's activity timestamp, which cleanup uses to
// expire idle sessions. Packets for unknown sessions are dropped with a
// warning. A write failing with bufio.ErrBufferFull is retried a bounded
// number of times after a short pause; any other write error is logged once
// and the packet is dropped.
func (server *Server) sendPacket(conn *net.UDPConn, addr *net.UDPAddr, pck *common.Packet) {
	const maxRetries = 5

	server.mu.Lock()
	session, ok := server.sessions[pck.Sid]
	if !ok {
		server.mu.Unlock()
		log.WithField("SessionID", hex.EncodeToString(pck.Sid[:])).Warn("Invalid Session")
		return
	}
	key := session.key
	session.time = time.Now()
	server.mu.Unlock()

	data := common.NewSymmetricSecurePacket(key, pck).ToBytes()
	for attempt := 0; ; attempt++ {
		_, err := conn.WriteToUDP(data, addr)
		if err == nil {
			return
		}
		// The retryable case must be tested before the generic failure path,
		// and the retry count is bounded so a persistently full buffer cannot
		// loop forever.
		// NOTE(review): net.UDPConn is not documented to return
		// bufio.ErrBufferFull; confirm which error should trigger a retry
		// (e.g. syscall.ENOBUFS).
		if errors.Is(err, bufio.ErrBufferFull) && attempt < maxRetries {
			time.Sleep(10 * time.Millisecond)
			continue
		}
		log.WithError(err).Error("Could not write Packet to UDP")
		return
	}
}
// handlePacket routes a decrypted packet to the handler for its flag:
// Request goes to sendPTE, Ack to handleAck and Resend to resend. Any other
// flag is logged as an error and the packet is ignored.
func (server *Server) handlePacket(conn *net.UDPConn, addr *net.UDPAddr, rPacket *common.Packet) {
	var handler func(*net.UDPConn, *net.UDPAddr, *common.Packet)
	switch rPacket.Flag {
	case common.Request:
		handler = server.sendPTE
	case common.Ack:
		handler = server.handleAck
	case common.Resend:
		handler = server.resend
	default:
		log.WithField("Packet Type", rPacket.Flag).Error("Unexpected Packet Type")
		return
	}
	handler(conn, addr, rPacket)
}
// resend answers a Resend packet by re-sending one chunk of the session's
// file. The packet's uint32 payload is the Sync number of the missing chunk.
//
// The chunk's file offset is computed as (sync-3)*common.MaxDataSize.
// NOTE(review): this formula is provisional; the "-3" presumably skips the
// Sync values consumed before the first data packet. Confirm it against the
// client's numbering.
//
// A Sync that maps to a negative offset, or to a position at or past the end
// of the file, is rejected with a warning rather than answered with an
// all-zero chunk. The final, partial chunk is resent at its real length,
// matching what sendData sends.
func (server *Server) resend(conn *net.UDPConn, addr *net.UDPAddr, pck *common.Packet) {
	resendSync, err := pck.GetUint32Payload()
	if err != nil {
		log.Error("Error getting Resend Sync from Packet")
		return
	}

	server.mu.Lock()
	session, ok := server.sessions[pck.Sid]
	var path string
	if ok {
		path = session.path
	}
	server.mu.Unlock()
	if !ok {
		log.WithField("SessionID", hex.EncodeToString(pck.Sid[:])).Warn("Invalid Session")
		return
	}

	offset := (int64(resendSync) - 3) * int64(common.MaxDataSize)
	if offset < 0 {
		log.WithFields(log.Fields{
			"Sync":      resendSync,
			"File Path": path,
		}).Warn("Invalid Resend Sync")
		return
	}

	file, err := os.Open(path)
	if err != nil {
		log.WithError(err).WithField("File Path", path).Error("Unable to open File")
		return
	}
	defer func(file *os.File) {
		if err := file.Close(); err != nil {
			log.WithError(err).Error("Could not close File")
		}
	}(file)

	buf := make([]byte, common.MaxDataSize)
	n, err := file.ReadAt(buf, offset)
	if err != nil && !errors.Is(err, io.EOF) {
		log.WithError(err).WithField("File Path", path).Error("Unable to read File")
		return
	}
	if n == 0 {
		log.WithFields(log.Fields{
			"Sync":      resendSync,
			"File Path": path,
		}).Warn("Resend requested beyond end of File")
		return
	}

	// Only the bytes actually read are sent; the tail of buf is unused for
	// the last chunk of the file.
	server.sendPacket(conn, addr, common.NewResendFile(pck, buf[:n]))
}
// handleAck processes an acknowledgement for the session's last sent packet.
//
// The packet's uint32 payload must equal the session's lastSync; mismatches
// are logged and ignored. An acknowledged End packet closes (deletes) the
// session; any other acknowledged packet triggers sendData for the session's
// file. server.mu is released on every path before logging, returning or
// sending. The previous version returned with the lock held on a mismatched
// Ack, which deadlocked the server.
func (server *Server) handleAck(conn *net.UDPConn, addr *net.UDPAddr, pck *common.Packet) {
	ack, err := pck.GetUint32Payload()
	if err != nil {
		log.WithError(err).Error("Getting Acknowledge from Packet")
		return
	}

	server.mu.Lock()
	session, ok := server.sessions[pck.Sid]
	if !ok {
		server.mu.Unlock()
		log.WithField("SessionID", hex.EncodeToString(pck.Sid[:])).Warn("Invalid Session")
		return
	}

	if ack != session.lastSync {
		expected := session.lastSync
		server.mu.Unlock()
		log.WithFields(log.Fields{
			"Expected": expected,
			"Received": ack,
		}).Warn("Received wrong Acknowledge")
		return
	}

	if session.lastPckSend == common.End {
		delete(server.sessions, pck.Sid)
		server.mu.Unlock()
		log.WithField("SessionID", hex.EncodeToString(pck.Sid[:])).Info("Closing Session")
		return
	}

	server.mu.Unlock()
	server.sendData(conn, addr, pck)
}
// sendPTE answers a Request packet. It resolves the requested path inside the
// server's data directory, records it on the session and replies with a PTE
// packet carrying the file size.
//
// Only regular files located directly inside parentFilePath are served: the
// cleaned path's parent directory must equal parentFilePath. Requests that
// fall outside it, name a directory or other non-regular file, or name a file
// too large for the uint32 size field are refused with a log entry.
// NOTE(review): os.Stat follows symlinks, so a symlink inside the data
// directory can still expose its target; confirm whether that is intended.
//
// The session's path, lastSync and lastPckSend are updated before the PTE is
// sent. A fast Ack, handled on another goroutine, then sees the new lastSync.
func (server *Server) sendPTE(conn *net.UDPConn, addr *net.UDPAddr, pck *common.Packet) {
	path, err := pck.GetFilePath()
	if err != nil {
		log.WithError(err).Error("Unable to get File Path")
		return
	}

	file := filepath.Clean(filepath.Join(server.parentFilePath, path))
	// Compare directories directly. filepath.Match would treat glob
	// metacharacters ('*', '?', '[') in parentFilePath as a pattern.
	if filepath.Dir(file) != server.parentFilePath {
		log.WithFields(log.Fields{
			"ParentFilePath":    server.parentFilePath,
			"RequestedFilePath": path,
			"CleanedFilePath":   file,
		}).Warn("Requesting File out of Path")
		return
	}

	fi, err := os.Stat(file)
	if err != nil {
		log.WithError(err).WithField("File Path", file).Error("Unable to open File")
		return
	}
	if !fi.Mode().IsRegular() {
		log.WithField("File Path", file).Warn("Requested File is not a regular File")
		return
	}
	fileSize := fi.Size()
	if fileSize > int64(^uint32(0)) {
		log.WithFields(log.Fields{
			"File Path": file,
			"Size":      fileSize,
		}).Warn("Requested File is too large")
		return
	}

	ptePck := common.NewPte(uint32(fileSize), pck)

	server.mu.Lock()
	session, ok := server.sessions[pck.Sid]
	if ok {
		session.path = file
		session.lastSync = ptePck.Sync
		session.lastPckSend = ptePck.Flag
	}
	server.mu.Unlock()
	if !ok {
		log.WithField("SessionID", hex.EncodeToString(pck.Sid[:])).Warn("Invalid Session")
		return
	}

	server.sendPacket(conn, addr, ptePck)
}
// sendData streams the session's file to addr as File packets of at most
// common.MaxDataSize bytes each, then sends a single End packet.
//
// Each File packet is built from the previous one, starting from the
// triggering pck. Before the End packet goes out, the session's lastSync and
// lastPckSend are set from it, so its acknowledgement closes the session in
// handleAck. The transfer stops with a log entry if the session is unknown or
// the file cannot be opened or read.
func (server *Server) sendData(conn *net.UDPConn, addr *net.UDPAddr, pck *common.Packet) {
	server.mu.Lock()
	session, found := server.sessions[pck.Sid]
	var path string
	if found {
		path = session.path
	}
	server.mu.Unlock()
	if !found {
		log.WithField("SessionID", hex.EncodeToString(pck.Sid[:])).Warn("Invalid Session")
		return
	}

	f, err := os.Open(path)
	if err != nil {
		log.WithError(err).WithField("File Path", path).Error("Unable to open File")
		return
	}
	defer func() {
		if closeErr := f.Close(); closeErr != nil {
			log.WithError(closeErr).Error("Could not close File")
		}
	}()

	chunk := make([]byte, common.MaxDataSize)
	last := pck
readLoop:
	for {
		n, readErr := f.Read(chunk)
		switch {
		case readErr != nil && !errors.Is(readErr, io.EOF):
			log.WithError(readErr).WithField("File Path", path).Error("Unable to read File")
			return
		case n == 0:
			break readLoop
		}
		last = common.NewFile(last, chunk[:n])
		server.sendPacket(conn, addr, last)
	}

	end := common.NewEnd(last)

	server.mu.Lock()
	session, found = server.sessions[pck.Sid]
	if found {
		session.lastSync = end.Sync
		session.lastPckSend = end.Flag
	}
	server.mu.Unlock()
	if !found {
		log.WithField("SessionID", hex.EncodeToString(pck.Sid[:])).Warn("Invalid Session")
		return
	}

	server.sendPacket(conn, addr, end)
}
// startTimeout runs cleanup whenever 30 seconds pass without a value arriving
// on interruptChan. It returns once true is received; a false value only
// restarts the 30-second wait.
func (server *Server) startTimeout(interruptChan chan bool) {
	for {
		select {
		case stop := <-interruptChan:
			if stop {
				return
			}
		case <-time.After(30 * time.Second):
			server.cleanup()
		}
	}
}
// cleanup deletes every session whose last send activity (info.time) lies
// more than 30 seconds in the past, logging each removal. It holds server.mu
// for the whole sweep.
func (server *Server) cleanup() {
	server.mu.Lock()
	defer server.mu.Unlock()

	for sid, session := range server.sessions {
		if time.Since(session.time) <= 30*time.Second {
			continue
		}
		// Deleting the current key while ranging over a map is safe in Go.
		delete(server.sessions, sid)
		log.WithField("SessionID", hex.EncodeToString(sid[:])).Info("Closed session")
	}
}
// handleShutdown installs an os.Interrupt handler. On the first interrupt it
// sends true on stop, logs the shutdown and exits the process with status 0.
// os.Exit does not run deferred functions in any goroutine.
func (server *Server) handleShutdown(stop chan bool) {
	interrupts := make(chan os.Signal, 1)
	signal.Notify(interrupts, os.Interrupt)
	go func() {
		<-interrupts
		stop <- true
		log.Info("Server is shutting down")
		os.Exit(0)
	}()
}
// handleConnection performs the TCP key-exchange handshake for one client.
//
// It reads one RSA packet (at most 2048 bytes, in a single Read) and extracts
// the client's symmetric key with the server's private key. It then registers
// a new session under the packet's session ID and replies "Yep". The
// connection is closed when the handshake ends, whether it succeeded or not.
//
// The new session's activity timestamp starts at the current time. With a
// zero timestamp, cleanup would reap the session before the client could send
// its first request.
func (server *Server) handleConnection(conn net.Conn) {
	defer func() {
		if err := conn.Close(); err != nil {
			log.WithError(err).Error("Could not close TCP connection")
		}
	}()

	reader := bufio.NewReader(conn)
	var buf [2048]byte
	// NOTE(review): TCP is a byte stream, so a single Read may return only
	// part of the RSA packet. Confirm the packet length so this can use
	// io.ReadFull.
	r, err := reader.Read(buf[:])
	if err != nil {
		log.WithError(err).Warn("Could not read from Connection")
		return
	}

	rsaPck := common.RsaPacketFromBytes(buf[0:r])
	key, err := rsaPck.ExtractKey(server.rsa)
	if err != nil && !errors.Is(err, io.EOF) {
		log.WithError(err).Warn("Could not extract Key")
		return
	}

	server.mu.Lock()
	server.sessions[rsaPck.Sid] = &info{
		key:  key,
		time: time.Now(),
	}
	server.mu.Unlock()

	if _, err := conn.Write([]byte("Yep")); err != nil {
		log.WithError(err).Error("Could not write to TCP connection")
		return
	}
	log.WithField("SessionID", hex.EncodeToString(rsaPck.Sid[:])).Info("Started Session")
}
// startManagement listens for handshake connections on TCP 0.0.0.0:13375 and
// serves each accepted connection on its own goroutine via handleConnection.
//
// The process is terminated (log.Fatal) if the listener cannot be created.
// A failed Accept is logged and skipped. Previously the nil connection was
// still handed to handleConnection, which panicked.
func (server *Server) startManagement() {
	listener, err := net.Listen("tcp", "0.0.0.0:13375")
	if err != nil {
		log.WithError(err).Fatal("Could not start listening on TCP 0.0.0.0:13375")
	}
	defer func(listener net.Listener) {
		if err := listener.Close(); err != nil {
			log.WithError(err).Error("Could not close TCP Listener")
		}
	}(listener)

	for {
		conn, err := listener.Accept()
		if err != nil {
			log.WithError(err).Warn("Could not accept TCP Connection")
			continue
		}
		go server.handleConnection(conn)
	}
}
// Serve starts the file server and never returns under normal operation.
//
// It binds UDP port 13374 on options.Address, installs the interrupt handler,
// and starts the session-expiry loop and the TCP management listener. It then
// reads datagrams in a loop. Each datagram is parsed as a secure packet,
// decrypted with its session's key and dispatched to handlePacket on a new
// goroutine. Datagrams that cannot be read, parsed, matched to a session or
// decrypted are logged and dropped; a failed decryption is never dispatched.
// Failure to resolve or bind the UDP address terminates the process.
func (server *Server) Serve() {
	udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%v:13374", server.options.Address))
	if err != nil {
		log.WithError(err).Fatal("Could not resolve UDP Address")
	}
	log.Infof("Starting server on %v:%v", udpAddr.IP, udpAddr.Port)

	conn, err := net.ListenUDP("udp", udpAddr)
	if err != nil {
		log.WithError(err).Fatal("Could not start listening")
	}
	log.Info("Started listening")

	stop := make(chan bool)
	server.handleShutdown(stop)
	go server.startTimeout(stop)
	go server.startManagement()

	for {
		var buf [common.PacketSize]byte
		_, addr, err := conn.ReadFromUDP(buf[0:])
		if err != nil {
			log.WithError(err).Error("Could not retrieve UDP Packet")
			continue
		}

		// NOTE(review): the whole buffer is parsed rather than the n bytes
		// actually received. This presumably relies on SecurePacketFromBytes
		// reading fixed offsets within a PacketSize buffer; confirm before
		// trimming.
		secPck, err := common.SecurePacketFromBytes(buf[:])
		if err != nil {
			log.WithError(err).Warn("Received invalid Packet")
			continue
		}

		// Only the key lookup needs the lock; decryption works on local copies.
		server.mu.Lock()
		session, ok := server.sessions[secPck.Sid]
		var key [32]byte
		if ok {
			key = session.key
		}
		server.mu.Unlock()
		if !ok {
			log.WithField("SessionID", hex.EncodeToString(secPck.Sid[:])).Warn("Invalid Session")
			continue
		}

		pck, err := secPck.ExtractPacket(key)
		if err != nil {
			log.WithError(err).Error("Could not extract Packet from Secure Packet")
			continue
		}
		go server.handlePacket(conn, addr, &pck)
	}
}