diff --git a/.gitignore b/.gitignore index 786b0c0..f83594c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ bin/ testFile testFile.recv -testFile2 \ No newline at end of file +testFile2 +out/ \ No newline at end of file diff --git a/client.go b/client.go index 90ba26f..5bd2986 100644 --- a/client.go +++ b/client.go @@ -1,16 +1,19 @@ package main import ( + "encoding/hex" "fmt" "net" "os" "sort" + "time" ) func GetFile(path string) { request := NewRequest(path) udpAddr, err := net.ResolveUDPAddr("udp", "0.0.0.0:13374") + // udpAddr, err := net.ResolveUDPAddr("udp", "192.168.2.145:13374") if err != nil { fmt.Println(err) @@ -31,7 +34,7 @@ func GetFile(path string) { } bytes := make([]byte, PacketSize) - file, err := os.Create(path + ".recv") + file, err := os.Create("out/" + hex.EncodeToString(request.sid[:]) + ".recv") if err != nil { panic(err) } @@ -88,61 +91,72 @@ func GetFile(path string) { }) lostPackets := make([]uint32, 0) - lastSync := ackPck.sync - needResend := false - for _, i := range recvPackets { - if lastSync+1 != i { + + for i := ackPck.sync + 1; i < endPacket.sync; i++ { + if b, _ := contains(recvPackets, i); !b { lostPackets = append(lostPackets, i) - needResend = true } - lastSync = i } - if !needResend { - ack := NewAck(&endPacket) - conn.Write(ack.ToBytes()) + for _, i := range lostPackets { + fmt.Println(i) } - // sort.Slice(recvPackets, func(i, j int) bool { - // pckI := recvPackets[i] - // pckJ := recvPackets[j] - // return pckI.sync < pckJ.sync - // }) + lastPacket := ackPck - // endPacketFound := false - // needResend := false - // lastSync := request.sync - // fmt.Println(lastSync) - // endPacketSync, err := endPacket.GetSync() - // if err != nil { - // panic(err) - // } + for { + if len(lostPackets) == 0 { + break + } - // for _, packet := range recvPackets { - // // fmt.Println(packet.sync) - // // offset := (int64(packet.sync)-1)*PacketSize - int64(HeaderSize) - // // data := packet.data - // // fmt.Printf("Data: %v Offset: %v\n", data, offset) + for _, sync := range lostPackets { - // if lastSync+1 != packet.sync { - // fmt.Printf("Need Packet %v resend\n", lastSync+1) - // // Add to slice - // needResend = true - // continue - // } + fmt.Printf("Request resend for %v\n", sync) + resend := NewResend(uint32(sync), lastPacket) + conn.Write(resend.ToBytes()) + lastPacket = resend - // fmt.Printf("Writing Packet %v to file\n", packet.sync) + conn.SetReadDeadline(time.Now().Add(10 * time.Second)) - // _, err = file.Write(packet.data) - // if err != nil { - // panic(err) - // } + _, _, err = conn.ReadFrom(bytes) + if err != nil { + if e, ok := err.(net.Error); !ok || !e.Timeout() { + // If it's not a timeout, log the error as usual + panic(err) + } + continue + } - // if packet.sync == endPacketSync { - // endPacketFound = true - // } + pck := PacketFromBytes(bytes) + offset := (int64(pck.sync) - int64(ackPck.sync+1)) * (PacketSize - int64(HeaderSize)) + // fmt.Printf("Sync: %v, Offset: %v\n", pck.sync, offset) - // lastSync = packet.sync - // } + _, err = file.WriteAt(pck.data, offset) + if err != nil { + panic(err) + } + _, index := contains(lostPackets, pck.sync) + fmt.Printf("Removing sync %v from LostPackets\n", pck.sync) + lostPackets = remove(lostPackets, index) + + } + } + + ack := NewAck(&endPacket) + conn.Write(ack.ToBytes()) +} + +func remove(s []uint32, i int) []uint32 { + s[i] = s[len(s)-1] + return s[:len(s)-1] +} + +func contains(s []uint32, e uint32) (bool, int) { + for i, a := range s { + if a == e { + return true, i + } + } + return false, 0 } diff --git a/main.go b/main.go index 8016dd6..b449b65 100644 --- a/main.go +++ b/main.go @@ -8,7 +8,10 @@ const PacketSize = 504 func main() { if os.Args[1] == "server" { - server := New() + server, err := New() + if err != nil { + panic(err) + } server.Serve() } else { GetFile(os.Args[2]) diff --git a/packets.go b/packets.go index 43bd4e5..06074b1 100644 --- a/packets.go +++ b/packets.go @@ -21,8 +21,8 @@ type Packet struct { } func PacketFromBytes(bytes []byte) Packet { - sid := SessionID(bytes[0:32]) - flag := HeaderFlag(bytes[32]) + flag := HeaderFlag(bytes[0]) + sid := SessionID(bytes[1:33]) sync := binary.LittleEndian.Uint32(bytes[33:37]) dataLength := binary.LittleEndian.Uint32(bytes[37:41]) pck := Packet{ @@ -30,7 +30,7 @@ func PacketFromBytes(bytes []byte) Packet { flag: flag, sync: sync, dataLength: dataLength, - data: bytes[41 : 41+dataLength], + data: bytes[HeaderSize : HeaderSize+int(dataLength)], } return pck } @@ -64,6 +64,32 @@ func NewRequest(path string) *Packet { } } +func (pck *Packet) GetUint32Payload() (uint32, error) { + flag := pck.flag + if flag != PTE && flag != Ack && flag != End && flag != Resend { + return 0, errors.New(fmt.Sprintf("Can not get Sync from Packet Type with flag: %v", flag)) + } + return binary.LittleEndian.Uint32(pck.data), nil +} + +func (pck *Packet) GetFilePath() (string, error) { + if pck.flag != Request { + return "", errors.New("Can not get FilePath from Packet that is not Request") + } + return string(pck.data), nil +} + +func NewResendFile(resendPck *Packet, data []byte) *Packet { + sync, _ := resendPck.GetUint32Payload() + return &Packet{ + sid: resendPck.sid, + flag: File, + sync: sync, + dataLength: uint32(len(data)), + data: data, + } +} + func NewFile(lastPck *Packet, data []byte) *Packet { return &Packet{ sid: lastPck.sid, @@ -110,25 +136,10 @@ func NewPte(fileSize uint32, lastPck *Packet) *Packet { } } -func (pck *Packet) GetUint32Payload() (uint32, error) { - flag := pck.flag - if flag != PTE && flag != Ack && flag != End && flag != Resend { - return 0, errors.New(fmt.Sprintf("Can not get Sync from Packet Type with flag: %v", flag)) - } - return binary.LittleEndian.Uint32(pck.data), nil -} - -func (pck *Packet) GetFilePath() (string, error) { - if pck.flag != Request { - return "", errors.New("Can not get FilePath from Packet that is not Request") - } - return string(pck.data), nil -} - func (pck *Packet) ToBytes() []byte { arr := make([]byte, HeaderSize+int(pck.dataLength)) - copy(arr[0:32], pck.sid[:]) - arr[32] = byte(pck.flag) + arr[0] = byte(pck.flag) + copy(arr[1:33], pck.sid[:]) binary.LittleEndian.PutUint32(arr[33:37], pck.sync) binary.LittleEndian.PutUint32(arr[37:41], pck.dataLength) copy(arr[41:], pck.data) diff --git a/server.go b/server.go index 4ada059..175211d 100644 --- a/server.go +++ b/server.go @@ -1,6 +1,8 @@ package main import ( + "crypto/rand" + "crypto/rsa" "encoding/hex" "errors" "fmt" @@ -13,16 +15,25 @@ type info struct { path string lastSync uint32 lastPckSend HeaderFlag + key [32]byte } type Server struct { sessions map[SessionID]*info + rsa *rsa.PrivateKey } -func New() *Server { +func New() (*Server, error) { + key, err := rsa.GenerateKey(rand.Reader, 4096) + + if err != nil { + return nil, err + } + return &Server{ sessions: make(map[SessionID]*info), - } + rsa: key, + }, nil } func (server *Server) handlePacket(conn *net.UDPConn, addr *net.UDPAddr, rPacket *Packet) { @@ -33,9 +44,42 @@ func (server *Server) handlePacket(conn *net.UDPConn, addr *net.UDPAddr, rPacket case Ack: server.handleAck(conn, addr, rPacket) break + case Resend: + server.resend(conn, addr, rPacket) } } +func (server *Server) resend(conn *net.UDPConn, addr *net.UDPAddr, pck *Packet) { + resend, err := pck.GetUint32Payload() + if err != nil { + panic(err) + } + + path := server.sessions[pck.sid].path + file, err := os.Open(path) + if err != nil { + panic(err) + } + defer file.Close() + + // This should be different + offset := (int64(resend) - 3) * (PacketSize - int64(HeaderSize)) + // fmt.Printf("Requested Sync: %v, Calculated Offset: %v\n", resend, offset) + buf := make([]byte, PacketSize-HeaderSize) + + _, err = file.ReadAt(buf, offset) + if err != nil && !errors.Is(err, io.EOF) { + panic(err) + } + + fmt.Printf("Resending Packet %v\n", resend) + + resendPck := NewResendFile(pck, buf) + + conn.WriteToUDP(resendPck.ToBytes(), addr) + +} + func (server *Server) handleAck(conn *net.UDPConn, addr *net.UDPAddr, pck *Packet) { ack, err := pck.GetUint32Payload() if err != nil { @@ -88,11 +132,7 @@ func (server *Server) sendData(conn *net.UDPConn, addr *net.UDPAddr, pck *Packet if err != nil { panic(err) } - - // // ONLY FOR TEST - // firstPacket := true - // var firstFilePckt Packet - // // END TEST + defer file.Close() buf := make([]byte, PacketSize-HeaderSize) filePck := pck @@ -107,27 +147,9 @@ func (server *Server) sendData(conn *net.UDPConn, addr *net.UDPAddr, pck *Packet filePck = NewFile(filePck, buf[:r]) fmt.Printf("Sending File Packet %v\n", filePck.sync) - // // ONLY FOR TEST - // if firstPacket { - // firstPacket = false - // firstFilePckt = Packet{ - // sid: filePck.sid, - // flag: File, - // sync: filePck.sync, - // dataLength: filePck.dataLength, - // data: make([]byte, filePck.dataLength), - // } - - // copy(firstFilePckt.data, filePck.data) - - // } else { - // // END conn.WriteToUDP(filePck.ToBytes(), addr) - // } } - // conn.WriteToUDP(firstFilePckt.ToBytes(), addr) - eodPck := NewEnd(filePck) server.sessions[pck.sid].lastSync = eodPck.sync server.sessions[pck.sid].lastPckSend = eodPck.flag