This repository has been archived on 2025-10-15. You can view files and clone it, but cannot push or open issues or pull requests.
Files
Uftp/internal/server/server.go
2023-12-08 20:39:52 +01:00

394 lines
8.6 KiB
Go

package server
import (
"bufio"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/hex"
"encoding/pem"
"errors"
"io"
"net"
"os"
"os/signal"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/Pablu23/Uftp/internal/common"
)
type (
	// info is the per-session state kept in Server.sessions, keyed by the
	// session ID the client chose during the TCP key exchange.
	info struct {
		// path is the file the client requested; it is set by sendPTE.
		path string
		// lastSync is the Sync of the last PTE or End packet sent, i.e. the
		// value handleAck expects the client to acknowledge.
		lastSync uint32
		// lastPckSend is the Flag of that same packet.
		lastPckSend common.HeaderFlag
		// key is the 32-byte symmetric key extracted from the client's
		// RSA-encrypted handshake packet.
		key [32]byte
		// time is refreshed by sendPacket; cleanup drops sessions whose time
		// is more than 30 seconds old.
		time time.Time
	}

	// Server transfers files over UDP. Clients first hand it a symmetric
	// session key over TCP, encrypted with the server's RSA key.
	Server struct {
		// sessions maps session IDs to their state; guarded by mu.
		sessions map[common.SessionID]*info
		// mu protects sessions and every *info stored in it.
		mu sync.Mutex
		// rsa is the private key used to decrypt client handshakes.
		rsa *rsa.PrivateKey
	}
)
// New builds a Server with an empty session table and a freshly generated
// 4096-bit RSA key pair. On success it also switches the global logrus
// formatter to forced-color text output. The only error returned is one from
// RSA key generation.
func New() (*Server, error) {
	privateKey, err := rsa.GenerateKey(rand.Reader, 4096)
	if err != nil {
		return nil, err
	}

	log.SetFormatter(&log.TextFormatter{ForceColors: true})

	srv := &Server{
		rsa:      privateKey,
		sessions: make(map[common.SessionID]*info),
	}
	return srv, nil
}
// SavePublicKeyPem writes the server's RSA public key to "pubkey.pem" in the
// current working directory, creating or truncating the file. The key is
// PKCS#1 DER wrapped in a PEM block of type "RSA PUBLIC KEY".
//
// It returns any error from creating, encoding into, or closing the file.
// Close is checked because a failed close can mean buffered data never
// reached disk.
func (server *Server) SavePublicKeyPem() (err error) {
	file, err := os.Create("pubkey.pem")
	if err != nil {
		return err
	}
	defer func() {
		if cerr := file.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	publicKeyPEM := &pem.Block{
		Type:  "RSA PUBLIC KEY",
		Bytes: x509.MarshalPKCS1PublicKey(&server.rsa.PublicKey),
	}
	return pem.Encode(file, publicKeyPEM)
}
// sendPacket encrypts pck with its session's symmetric key and writes it to
// addr. As a side effect it refreshes the session's time stamp, which cleanup
// uses for idle expiry. A packet for an unknown session is dropped with a
// warning. A write failure is logged together with its error.
func (server *Server) sendPacket(conn *net.UDPConn, addr *net.UDPAddr, pck *common.Packet) {
	server.mu.Lock()
	session, ok := server.sessions[pck.Sid]
	if !ok {
		server.mu.Unlock()
		log.WithField("SessionID", hex.EncodeToString(pck.Sid[:])).Warn("Invalid Session")
		return
	}
	key := session.key
	session.time = time.Now()
	// Release the lock before encrypting and doing network I/O.
	server.mu.Unlock()

	secPck := common.NewSymmetricSecurePacket(key, pck)
	if _, err := conn.WriteToUDP(secPck.ToBytes(), addr); err != nil {
		log.WithError(err).Error("Could not write Packet to UDP")
	}
}
// handlePacket dispatches a decrypted client packet by its flag:
//   - Request starts a transfer (sendPTE).
//   - Ack advances it (handleAck).
//   - Resend re-sends one data chunk (resend).
//
// Any other flag is logged and the packet is dropped. Go switch cases never
// fall through, so no break statements are needed.
func (server *Server) handlePacket(conn *net.UDPConn, addr *net.UDPAddr, rPacket *common.Packet) {
	switch rPacket.Flag {
	case common.Request:
		server.sendPTE(conn, addr, rPacket)
	case common.Ack:
		server.handleAck(conn, addr, rPacket)
	case common.Resend:
		server.resend(conn, addr, rPacket)
	default:
		log.WithField("Packet Type", rPacket.Flag).Error("Unexpected Packet Type")
	}
}
// resend answers a Resend packet by re-sending one data chunk of the
// session's file. The packet's uint32 payload is the Sync value of the
// missing chunk. The chunk is read with ReadAt at the matching offset and
// sent back via common.NewResendFile.
//
// Only the bytes actually read are sent, so the (usually short) final chunk
// is not padded with zeros. This matches how sendData sends chunks.
func (server *Server) resend(conn *net.UDPConn, addr *net.UDPAddr, pck *common.Packet) {
	resend, err := pck.GetUint32Payload()
	if err != nil {
		log.WithError(err).Error("Error getting Resend Sync from Packet")
		return
	}

	server.mu.Lock()
	session, ok := server.sessions[pck.Sid]
	var path string
	if ok {
		path = session.path
	}
	server.mu.Unlock()
	if !ok {
		log.WithField("SessionID", hex.EncodeToString(pck.Sid[:])).Warn("Invalid Session")
		return
	}

	// NOTE(review): the Sync -> chunk-index mapping assumes the first File
	// packet carries Sync 3. That depends on how common.NewPte/NewFile number
	// packets — TODO confirm. Lower values would give a negative offset.
	const firstDataSync = 3
	if resend < firstDataSync {
		log.WithField("Sync", resend).Warn("Resend requested for a non-data Sync")
		return
	}

	file, err := os.Open(path)
	if err != nil {
		log.WithError(err).WithField("File Path", path).Error("Unable to open File")
		return
	}
	defer file.Close()

	offset := (int64(resend) - firstDataSync) * int64(common.MaxDataSize)
	buf := make([]byte, common.MaxDataSize)
	// ReadAt reports io.EOF together with a short count for the last chunk.
	n, err := file.ReadAt(buf, offset)
	if err != nil && !errors.Is(err, io.EOF) {
		log.WithError(err).WithField("File Path", path).Error("Unable to read File")
		return
	}
	if n == 0 {
		log.WithFields(log.Fields{
			"File Path": path,
			"Sync":      resend,
		}).Warn("Resend requested beyond end of File")
		return
	}

	resendPck := common.NewResendFile(pck, buf[:n])
	server.sendPacket(conn, addr, resendPck)
}
// handleAck processes a client acknowledgement. The uint32 payload must equal
// the session's lastSync; otherwise the ack is logged and ignored.
//   - If the acknowledged packet was the End packet, the session is deleted.
//   - Otherwise (the PTE was acknowledged) the file data is streamed via
//     sendData.
//
// Every return path releases server.mu. The mismatch path previously
// returned with the lock held, deadlocking the server.
func (server *Server) handleAck(conn *net.UDPConn, addr *net.UDPAddr, pck *common.Packet) {
	ack, err := pck.GetUint32Payload()
	if err != nil {
		log.WithError(err).Error("Getting Acknowledge from Packet")
		return
	}

	server.mu.Lock()
	session, ok := server.sessions[pck.Sid]
	if !ok {
		server.mu.Unlock()
		log.WithField("SessionID", hex.EncodeToString(pck.Sid[:])).Warn("Invalid Session")
		return
	}

	if ack != session.lastSync {
		expected := session.lastSync
		server.mu.Unlock()
		log.WithFields(log.Fields{
			"Expected": expected,
			"Received": ack,
		}).Warn("Received wrong Acknowledge")
		return
	}

	if session.lastPckSend == common.End {
		delete(server.sessions, pck.Sid)
		server.mu.Unlock()
		log.WithField("SessionID", hex.EncodeToString(pck.Sid[:])).Info("Closing Session")
		return
	}

	// sendData takes the lock itself, so release it first.
	server.mu.Unlock()
	server.sendData(conn, addr, pck)
}
// sendPTE answers a Request packet. It validates the requested file, records
// the path and the PTE packet's Sync/Flag on the session, then sends the PTE
// packet (carrying the file size as a uint32) to the client.
//
// Only regular files that fit in a uint32 size are served.
//
// NOTE(review): the path comes straight from the client and is not confined
// to any directory, so any file readable by the server process can be
// requested. Consider restricting it to a served root.
func (server *Server) sendPTE(conn *net.UDPConn, addr *net.UDPAddr, pck *common.Packet) {
	path, err := pck.GetFilePath()
	if err != nil {
		log.WithError(err).Error("Unable to get File Path")
		return
	}
	fi, err := os.Stat(path)
	if err != nil {
		log.WithError(err).WithField("File Path", path).Error("Unable to open File")
		return
	}
	// Directories fail later on Read, and devices such as /dev/zero would
	// make sendData stream forever; only regular files have a usable size.
	if !fi.Mode().IsRegular() {
		log.WithField("File Path", path).Error("Requested path is not a regular File")
		return
	}
	// The PTE packet carries the size as a uint32; a larger file would be
	// announced with a silently truncated size.
	const maxFileSize = int64(^uint32(0))
	fileSize := fi.Size()
	if fileSize > maxFileSize {
		log.WithFields(log.Fields{
			"File Path": path,
			"Size":      fileSize,
		}).Error("File too large to transfer")
		return
	}
	ptePck := common.NewPte(uint32(fileSize), pck)

	// Record the expected acknowledgement before sending. Serve handles each
	// incoming packet in its own goroutine, so the client's Ack could
	// otherwise be processed before lastSync is updated.
	server.mu.Lock()
	session, ok := server.sessions[pck.Sid]
	if !ok {
		server.mu.Unlock()
		log.WithField("SessionID", hex.EncodeToString(pck.Sid[:])).Warn("Invalid Session")
		return
	}
	session.path = path
	session.lastSync = ptePck.Sync
	session.lastPckSend = ptePck.Flag
	server.mu.Unlock()

	server.sendPacket(conn, addr, ptePck)
}
// sendData streams the session's file to addr as a chain of File packets,
// each holding at most common.MaxDataSize bytes. Each packet is derived from
// the previous one via common.NewFile. It finishes with an End packet whose
// Sync and Flag are stored on the session as the acknowledgement to wait for.
// An unknown session, or a file that cannot be opened or read, aborts the
// transfer with a log entry.
func (server *Server) sendData(conn *net.UDPConn, addr *net.UDPAddr, pck *common.Packet) {
	server.mu.Lock()
	session, found := server.sessions[pck.Sid]
	if !found {
		log.WithField("SessionID", hex.EncodeToString(pck.Sid[:])).Warn("Invalid Session")
		server.mu.Unlock()
		return
	}
	filePath := session.path
	server.mu.Unlock()

	f, openErr := os.Open(filePath)
	if openErr != nil {
		log.WithError(openErr).WithField("File Path", filePath).Error("Unable to open File")
		return
	}
	defer f.Close()

	chunk := make([]byte, common.MaxDataSize)
	previous := pck
	for {
		n, readErr := f.Read(chunk)
		if readErr != nil && !errors.Is(readErr, io.EOF) {
			log.WithError(readErr).WithField("File Path", filePath).Error("Unable to read File")
			return
		}
		if n == 0 {
			break
		}
		previous = common.NewFile(previous, chunk[:n])
		server.sendPacket(conn, addr, previous)
	}

	endPck := common.NewEnd(previous)

	server.mu.Lock()
	session, found = server.sessions[pck.Sid]
	if !found {
		log.WithField("SessionID", hex.EncodeToString(pck.Sid[:])).Warn("Invalid Session")
		server.mu.Unlock()
		return
	}
	session.lastSync = endPck.Sync
	session.lastPckSend = endPck.Flag
	server.mu.Unlock()

	server.sendPacket(conn, addr, endPck)
}
// startTimeout runs the session-expiry loop, calling cleanup every 30
// seconds. It returns when true is received on interuptChan or when the
// channel is closed. A false value is ignored.
//
// A single Ticker replaces a per-iteration time.After. The comma-ok receive
// prevents a closed channel from turning the loop into a busy spin, since a
// closed channel yields false forever.
func (server *Server) startTimeout(interuptChan chan bool) {
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case stop, ok := <-interuptChan:
			if !ok || stop {
				return
			}
		case <-ticker.C:
			server.cleanup()
		}
	}
}
// cleanup removes every session whose time stamp is more than 30 seconds in
// the past and logs each removal. The server mutex is held for the whole
// sweep. Deleting from a map while ranging over it is safe in Go.
func (server *Server) cleanup() {
	server.mu.Lock()
	defer server.mu.Unlock()

	for sid, session := range server.sessions {
		if time.Since(session.time) <= 30*time.Second {
			continue
		}
		delete(server.sessions, sid)
		log.WithField("SessionID", hex.EncodeToString(sid[:])).Info("Closed session")
	}
}
// handleShutdown installs an os.Interrupt handler. On the first interrupt a
// background goroutine sends true on stop, logs the shutdown and terminates
// the process with exit code 0. The send blocks until someone receives from
// stop.
func (server *Server) handleShutdown(stop chan bool) {
	signals := make(chan os.Signal, 1)
	signal.Notify(signals, os.Interrupt)

	go func() {
		<-signals
		stop <- true
		log.Info("Server is shutting down")
		os.Exit(0)
	}()
}
// handleConnection performs the TCP key exchange for one client:
//  1. It reads one RSA-encrypted packet of at most 2048 bytes.
//  2. It extracts the symmetric session key with the server's private key.
//  3. It registers a new session under the packet's session ID.
//  4. It replies "Yep".
//
// The connection is closed on every return path, and the whole exchange is
// bounded by a deadline.
func (server *Server) handleConnection(conn net.Conn) {
	defer conn.Close()

	// Without a deadline, a client that connects but never sends would hold
	// this goroutine and its socket open indefinitely.
	const handshakeTimeout = 10 * time.Second
	if err := conn.SetDeadline(time.Now().Add(handshakeTimeout)); err != nil {
		log.WithError(err).Warn("Could not set Connection deadline")
		return
	}

	reader := bufio.NewReader(conn)
	var buf [2048]byte
	r, err := reader.Read(buf[:])
	if err != nil {
		log.WithError(err).Warn("Could not read from Connection")
		return
	}

	rsaPck := common.RsaPacketFromBytes(buf[:r])
	key, err := rsaPck.ExtractKey(server.rsa)
	// NOTE(review): io.EOF from ExtractKey is treated as success and the key
	// is used anyway — confirm this is intended.
	if err != nil && !errors.Is(err, io.EOF) {
		log.WithError(err).Warn("Could not extract Key")
		return
	}

	server.mu.Lock()
	// Stamp the session with the current time. A zero time would make
	// cleanup treat it as idle and evict it on the very next sweep.
	server.sessions[rsaPck.Sid] = &info{
		key:  key,
		time: time.Now(),
	}
	server.mu.Unlock()

	if _, err := conn.Write([]byte("Yep")); err != nil {
		log.WithError(err).Warn("Could not write to Connection")
	}
}
// startManagement listens on TCP 0.0.0.0:13375 for key-exchange connections
// and handles each one in its own goroutine via handleConnection.
//   - If the listener cannot be created, the process exits via log.Fatal.
//   - A failed Accept is logged and skipped, so no nil connection is handed
//     to handleConnection.
//   - The loop ends once the listener is closed.
func (server *Server) startManagement() {
	listener, err := net.Listen("tcp", "0.0.0.0:13375")
	if err != nil {
		log.WithError(err).Fatal("Could not start listening on TCP 0.0.0.0:13375")
	}
	defer listener.Close()

	for {
		conn, err := listener.Accept()
		if err != nil {
			if errors.Is(err, net.ErrClosed) {
				return
			}
			log.WithError(err).Warn("Could not accept TCP Connection")
			continue
		}
		go server.handleConnection(conn)
	}
}
// Serve runs the server and never returns normally:
//   - It binds UDP 0.0.0.0:13374.
//   - It installs the interrupt handler and starts the session-expiry loop
//     and the TCP key-exchange listener.
//   - It then reads datagrams forever. Each datagram is decrypted with its
//     session's key and dispatched to handlePacket in its own goroutine.
//
// Datagrams that cannot be parsed, belong to an unknown session, or fail to
// decrypt are logged and dropped. Failure to bind exits the process via
// log.Fatal.
func (server *Server) Serve() {
	udpAddr, err := net.ResolveUDPAddr("udp", "0.0.0.0:13374")
	if err != nil {
		log.WithError(err).Fatal("Could not resolve UDP Address")
	}
	log.Infof("Starting server on %v:%v", udpAddr.IP, udpAddr.Port)
	conn, err := net.ListenUDP("udp", udpAddr)
	if err != nil {
		log.WithError(err).Fatal("Could not start listening")
	}
	log.Info("Started listening")

	c := make(chan bool)
	server.handleShutdown(c)
	go server.startTimeout(c)
	go server.startManagement()

	for {
		// A fresh buffer per datagram, because the parsed packet is handed to
		// another goroutine.
		var buf [common.PacketSize]byte
		_, addr, err := conn.ReadFromUDP(buf[:])
		if err != nil {
			log.WithError(err).Error("Could not retrieve UDP Packet")
			continue
		}
		// NOTE(review): the full fixed-size buffer is parsed regardless of the
		// bytes actually read; presumably packets are always PacketSize long —
		// TODO confirm.
		secPck, err := common.SecurePacketFromBytes(buf[:])
		if err != nil {
			log.WithError(err).Warn("Received invalid Packet")
			continue
		}

		server.mu.Lock()
		session, ok := server.sessions[secPck.Sid]
		var key [32]byte
		if ok {
			key = session.key
		}
		server.mu.Unlock()
		if !ok {
			log.WithField("SessionID", hex.EncodeToString(secPck.Sid[:])).Warn("Invalid Session")
			continue
		}

		// Decrypt outside the lock. Drop the datagram on failure instead of
		// dispatching a zero-value packet.
		pck, err := secPck.ExtractPacket(key)
		if err != nil {
			log.WithError(err).Error("Could not extract Packet from Secure Packet")
			continue
		}
		go server.handlePacket(conn, addr, &pck)
	}
}