diff --git a/client.go b/client.go index 5bd2986..76bce69 100644 --- a/client.go +++ b/client.go @@ -1,6 +1,7 @@ package main import ( + "crypto/rand" "encoding/hex" "fmt" "net" @@ -9,9 +10,25 @@ import ( "time" ) +func SendPacket(pck *Packet, key [32]byte, conn *net.UDPConn) { + secPck := NewSymetricSecurePacket(key, pck) + fmt.Println(secPck) + if _, err := conn.Write(secPck.ToBytes()); err != nil { + panic(err) + } +} + func GetFile(path string) { request := NewRequest(path) + k := make([]byte, 32) + _, err := rand.Read(k) + if err != nil { + panic(err) + } + key := [32]byte(k) + keyExchangePck := NewRsaPacket(request.sid, key) + udpAddr, err := net.ResolveUDPAddr("udp", "0.0.0.0:13374") // udpAddr, err := net.ResolveUDPAddr("udp", "192.168.2.145:13374") @@ -28,11 +45,13 @@ func GetFile(path string) { os.Exit(1) } - _, err = conn.Write(request.ToBytes()) + _, err = conn.Write(keyExchangePck.ToBytes()) if err != nil { panic(err) } + SendPacket(request, key, conn) + bytes := make([]byte, PacketSize) file, err := os.Create("out/" + hex.EncodeToString(request.sid[:]) + ".recv") if err != nil { @@ -57,7 +76,7 @@ func GetFile(path string) { file.Truncate(int64(size)) ackPck := NewAck(&pck) - conn.Write(ackPck.ToBytes()) + SendPacket(ackPck, key, conn) recvPackets := make([]uint32, 0) var endPacket Packet @@ -113,7 +132,7 @@ func GetFile(path string) { fmt.Printf("Request resend for %v\n", sync) resend := NewResend(uint32(sync), lastPacket) - conn.Write(resend.ToBytes()) + SendPacket(resend, key, conn) lastPacket = resend conn.SetReadDeadline(time.Now().Add(10 * time.Second)) @@ -144,7 +163,7 @@ func GetFile(path string) { } ack := NewAck(&endPacket) - conn.Write(ack.ToBytes()) + SendPacket(ack, key, conn) } func remove(s []uint32, i int) []uint32 { diff --git a/go.mod b/go.mod index ae29ebe..a9a7408 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,7 @@ module pablu/uftp -go 1.21.1 \ No newline at end of file +go 1.21.1 + +require golang.org/x/crypto v0.15.0 + +require golang.org/x/sys v0.14.0 // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..4888ad1 --- /dev/null +++ b/go.sum @@ -0,0 +1,4 @@ +golang.org/x/crypto v0.15.0 h1:frVn1TEaCEaZcn3Tmd7Y2b5KKPaZ+I32Q2OA3kYp5TA= +golang.org/x/crypto v0.15.0/go.mod h1:4ChreQoLWfG3xLDer1WdlH5NdlQ3+mwnQq1YTKY+72g= +golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= +golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= diff --git a/packets.go b/packets.go index 06074b1..4c6f758 100644 --- a/packets.go +++ b/packets.go @@ -5,12 +5,23 @@ import ( "encoding/binary" "errors" "fmt" + + "golang.org/x/crypto/chacha20poly1305" ) const HeaderSize int = 32 + 1 + 4 + 4 +const SecureHeaderSize int = 1 + 42 + 32 + 4 type SessionID [32]byte +type SecurePacket struct { + isRsa byte // 0 = false everything else is true + nonce [24]byte + sid SessionID + dataLength uint32 + encryptedData []byte +} + type Packet struct { // headerLength uint32 sid SessionID @@ -20,6 +31,84 @@ type Packet struct { data []byte } +func NewSymetricSecurePacket(key [32]byte, pck *Packet) *SecurePacket { + sid := pck.sid + data := pck.ToBytes() + aead, err := chacha20poly1305.NewX(key[:]) + if err != nil { + panic(err) + } + + nonce := make([]byte, 24) + if _, err = rand.Read(nonce); err != nil { + panic(err) + } + + encrypted := make([]byte, len(data)+aead.Overhead()) + encrypted = aead.Seal(nil, nonce, data, nil) + + return &SecurePacket{ + isRsa: 0, + nonce: [24]byte(nonce), + sid: sid, + dataLength: uint32(len(encrypted)), + encryptedData: encrypted, + } +} + +func SecurePacketFromBytes(bytes []byte) SecurePacket { + isRsa := bytes[0] + nonce := bytes[1:25] + sid := SessionID(bytes[25:57]) + length := binary.LittleEndian.Uint32(bytes[57:61]) + enc := bytes[61 : 61+length] + + return SecurePacket{ + isRsa: isRsa, + nonce: [24]byte(nonce), + sid: sid, + encryptedData: enc, + dataLength: length, + } +} + +func (secPck *SecurePacket) ToBytes() []byte { + arr := make([]byte, SecureHeaderSize+len(secPck.encryptedData)) + arr[0] = secPck.isRsa + copy(arr[1:25], secPck.nonce[:]) + copy(arr[25:57], secPck.sid[:]) + binary.LittleEndian.PutUint32(arr[57:61], secPck.dataLength) + copy(arr[61:], secPck.encryptedData) + + return arr +} + +func (secPck *SecurePacket) ExtractPacket(key [32]byte) (*Packet, error) { + aead, err := chacha20poly1305.NewX(key[:]) + if err != nil { + panic(err) + } + data, err := aead.Open(nil, secPck.nonce[:], secPck.encryptedData, nil) + if err != nil { + return nil, err + } + packet := PacketFromBytes(data) + return &packet, nil +} + +func NewRsaPacket(sid SessionID, key [32]byte) *SecurePacket { + return &SecurePacket{ + isRsa: 1, + nonce: [24]byte(make([]byte, 24)), + sid: sid, + encryptedData: key[:], + } +} + +func (secPck *SecurePacket) ExtractKey( /*RSA HERE LATER*/ ) []byte { + return secPck.encryptedData[:32] +} + func PacketFromBytes(bytes []byte) Packet { flag := HeaderFlag(bytes[0]) sid := SessionID(bytes[1:33]) diff --git a/server.go b/server.go index 175211d..0354ec3 100644 --- a/server.go +++ b/server.go @@ -119,11 +119,9 @@ func (server *Server) sendPTE(conn *net.UDPConn, addr *net.UDPAddr, pck *Packet) ptePck := NewPte(uint32(fileSize), pck) conn.WriteToUDP(ptePck.ToBytes(), addr) - server.sessions[pck.sid] = &info{ - path: path, - lastSync: ptePck.sync, - lastPckSend: ptePck.flag, - } + server.sessions[pck.sid].path = path + server.sessions[pck.sid].lastSync = ptePck.sync + server.sessions[pck.sid].lastPckSend = ptePck.flag } func (server *Server) sendData(conn *net.UDPConn, addr *net.UDPAddr, pck *Packet) { @@ -180,7 +178,27 @@ func (server *Server) Serve() { fmt.Println(err) return } - pck := PacketFromBytes(buf[:]) - go server.handlePacket(conn, addr, &pck) + + secPck := SecurePacketFromBytes(buf[:]) + + fmt.Println(secPck) + + if secPck.isRsa == 0 { + key := server.sessions[secPck.sid].key + pck, err := secPck.ExtractPacket(key) + if err != nil { + fmt.Println(err) + //fmt.Println("Could not establish secure connection") + } + go server.handlePacket(conn, addr, pck) + } else { + key := secPck.ExtractKey() + fmt.Println(key) + fmt.Println(secPck.sid) + server.sessions[secPck.sid] = &info{ + key: [32]byte(key), + } + + } } }