package server import ( "bufio" "crypto/rand" "crypto/rsa" "crypto/x509" "encoding/hex" "encoding/pem" "errors" "io" "net" "os" "os/signal" "sync" "time" log "github.com/sirupsen/logrus" "github.com/Pablu23/Uftp/internal/common" ) type info struct { path string lastSync uint32 lastPckSend common.HeaderFlag key [32]byte time time.Time } type Server struct { sessions map[common.SessionID]*info mu sync.Mutex rsa *rsa.PrivateKey } func New() (*Server, error) { key, err := rsa.GenerateKey(rand.Reader, 4096) if err != nil { return nil, err } log.SetFormatter(&log.TextFormatter{ ForceColors: true, }) return &Server{ sessions: make(map[common.SessionID]*info), rsa: key, }, nil } func (server *Server) SavePublicKeyPem() error { file, err := os.Create("pubkey.pem") if err != nil { return err } defer file.Close() publicKeyPEM := &pem.Block{ Type: "RSA PUBLIC KEY", Bytes: x509.MarshalPKCS1PublicKey(&server.rsa.PublicKey), } pem.Encode(file, publicKeyPEM) return nil } func (server *Server) sendPacket(conn *net.UDPConn, addr *net.UDPAddr, pck *common.Packet) { server.mu.Lock() var key [32]byte if info, ok := server.sessions[pck.Sid]; ok { key = info.key server.sessions[pck.Sid].time = time.Now() server.mu.Unlock() } else { log.WithField("SessionID", hex.EncodeToString(pck.Sid[:])).Warn("Invalid Session") server.mu.Unlock() return } secPck := common.NewSymmetricSecurePacket(key, pck) if _, err := conn.WriteToUDP(secPck.ToBytes(), addr); err != nil { log.Error("Could not write Packet to UDP") return } } func (server *Server) handlePacket(conn *net.UDPConn, addr *net.UDPAddr, rPacket *common.Packet) { switch rPacket.Flag { case common.Request: server.sendPTE(conn, addr, rPacket) break case common.Ack: server.handleAck(conn, addr, rPacket) break case common.Resend: server.resend(conn, addr, rPacket) break default: log.WithField("Packet Type", rPacket.Flag).Error("Unexpected Packet Type") break } } func (server *Server) resend(conn *net.UDPConn, addr *net.UDPAddr, pck *common.Packet) { resend, err := pck.GetUint32Payload() if err != nil { log.Error("Error getting Resend Sync from Packet") return } server.mu.Lock() var path string if info, ok := server.sessions[pck.Sid]; ok { path = info.path server.mu.Unlock() } else { log.WithField("SessionID", hex.EncodeToString(pck.Sid[:])).Warn("Invalid Session") server.mu.Unlock() return } file, err := os.Open(path) if err != nil { log.WithError(err).WithField("File Path", path).Error("Unable to open File") return } defer file.Close() // This should be different offset := (int64(resend) - 3) * (int64(common.MaxDataSize)) buf := make([]byte, common.MaxDataSize) _, err = file.ReadAt(buf, offset) if err != nil && !errors.Is(err, io.EOF) { log.WithError(err).WithField("File Path", path).Error("Unable to read File") return } resendPck := common.NewResendFile(pck, buf) server.sendPacket(conn, addr, resendPck) } func (server *Server) handleAck(conn *net.UDPConn, addr *net.UDPAddr, pck *common.Packet) { ack, err := pck.GetUint32Payload() if err != nil { log.WithError(err).Error("Getting Acknowledge from Packet") return } server.mu.Lock() if session, ok := server.sessions[pck.Sid]; ok { if ack != session.lastSync { log.WithFields(log.Fields{ "Expected": session.lastSync, "Received": ack, }).Warn("Received wrong Acknowledge") return } if session.lastPckSend == common.End { log.WithField("SessionID", hex.EncodeToString(pck.Sid[:])).Info("Closing Session") delete(server.sessions, pck.Sid) server.mu.Unlock() } else { server.mu.Unlock() server.sendData(conn, addr, pck) } } else { log.WithField("SessionID", hex.EncodeToString(pck.Sid[:])).Warn("Invalid Session") server.mu.Unlock() return } } func (server *Server) sendPTE(conn *net.UDPConn, addr *net.UDPAddr, pck *common.Packet) { path, err := pck.GetFilePath() if err != nil { log.WithError(err).Error("Unable to get File Path") return } fi, err := os.Stat(path) if err != nil { log.WithError(err).WithField("File Path", path).Error("Unable to open File") return } fileSize := fi.Size() ptePck := common.NewPte(uint32(fileSize), pck) server.sendPacket(conn, addr, ptePck) server.mu.Lock() if info, ok := server.sessions[pck.Sid]; ok { info.path = path info.lastSync = ptePck.Sync info.lastPckSend = ptePck.Flag server.mu.Unlock() } else { log.WithField("SessionID", hex.EncodeToString(pck.Sid[:])).Warn("Invalid Session") server.mu.Unlock() return } } func (server *Server) sendData(conn *net.UDPConn, addr *net.UDPAddr, pck *common.Packet) { var path string server.mu.Lock() if info, ok := server.sessions[pck.Sid]; ok { path = info.path server.mu.Unlock() } else { log.WithField("SessionID", hex.EncodeToString(pck.Sid[:])).Warn("Invalid Session") server.mu.Unlock() return } file, err := os.Open(path) if err != nil { log.WithError(err).WithField("File Path", path).Error("Unable to open File") return } defer file.Close() buf := make([]byte, common.MaxDataSize) filePck := pck for { r, err := file.Read(buf) if err != nil && !errors.Is(err, io.EOF) { log.WithError(err).WithField("File Path", path).Error("Unable to read File") return } if r == 0 { break } filePck = common.NewFile(filePck, buf[:r]) server.sendPacket(conn, addr, filePck) } eodPck := common.NewEnd(filePck) server.mu.Lock() if info, ok := server.sessions[pck.Sid]; ok { info.lastSync = eodPck.Sync info.lastPckSend = eodPck.Flag server.mu.Unlock() } else { log.WithField("SessionID", hex.EncodeToString(pck.Sid[:])).Warn("Invalid Session") server.mu.Unlock() return } server.sendPacket(conn, addr, eodPck) } func (server *Server) startTimeout(interuptChan chan bool) { running := true for running { select { case c := <-interuptChan: if c { running = false } break case <-time.After(time.Second * 30): server.cleanup() break } } } func (server *Server) cleanup() { server.mu.Lock() for sid, info := range server.sessions { if time.Now().After(info.time.Add(30 * time.Second)) { delete(server.sessions, sid) log.WithField("SessionID", hex.EncodeToString(sid[:])).Info("Closed session") } } server.mu.Unlock() } func (server *Server) handleShutdown(stop chan bool) { c := make(chan os.Signal, 1) signal.Notify(c, os.Interrupt) go func() { for range c { stop <- true log.Info("Server is shutting down") os.Exit(0) } }() } func (server *Server) handleConnection(conn net.Conn) { reader := bufio.NewReader(conn) var buf [2048]byte r, err := reader.Read(buf[:]) if err != nil { log.WithError(err).Warn("Could not read from Connection") conn.Close() return } // fmt.Println(buf) rsaPck := common.RsaPacketFromBytes(buf[0:r]) key, err := rsaPck.ExtractKey(server.rsa) if err != nil && !errors.Is(err, io.EOF) { log.WithError(err).Warn("Could not extract Key") return } server.mu.Lock() server.sessions[rsaPck.Sid] = &info{ key: key, } server.mu.Unlock() conn.Write([]byte("Yep")) conn.Close() } func (server *Server) startManagement() { listener, err := net.Listen("tcp", "0.0.0.0:13375") if err != nil { log.Fatal("Could not start listening on TCP 0.0.0.0:13375") } defer listener.Close() for { conn, err := listener.Accept() if err != nil { log.WithError(err).Warn("Could not accept TCP Connection") } go server.handleConnection(conn) } } func (server *Server) Serve() { udpAddr, err := net.ResolveUDPAddr("udp", "0.0.0.0:13374") if err != nil { log.Fatal("Could not resolve UDP Address") } log.Infof("Starting server on %v:%v", udpAddr.IP, udpAddr.Port) conn, err := net.ListenUDP("udp", udpAddr) if err != nil { log.Fatal("Could not start listening") } log.Info("Started listening") c := make(chan bool) server.handleShutdown(c) go server.startTimeout(c) go server.startManagement() for { var buf [common.PacketSize]byte _, addr, err := conn.ReadFromUDP(buf[0:]) if err != nil { log.Error("Could not retrieve UDP Packet") continue } secPck, err := common.SecurePacketFromBytes(buf[:]) if err != nil { log.WithError(err).Warn("Received invalid Packet") continue } var key [32]byte server.mu.Lock() if info, ok := server.sessions[secPck.Sid]; ok { key = info.key } else { log.WithField("SessionID", hex.EncodeToString(secPck.Sid[:])).Warn("Invalid Session") server.mu.Unlock() continue } pck, err := secPck.ExtractPacket(key) if err != nil { log.Error("Could not extract Packet from Secure Packet") } server.mu.Unlock() go server.handlePacket(conn, addr, &pck) } }