diff --git a/device/device_test.go b/device/device_test.go index f9eb50c..6eb52e9 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -17,7 +17,6 @@ import ( "os/signal" "runtime" "runtime/pprof" - "strconv" "sync" "testing" "time" @@ -130,7 +129,6 @@ func (pair *testPair) Send( tb testing.TB, ping SendDirection, done chan struct{}, - optTransportJunk ...int, ) { tb.Helper() p0, p1 := pair[0], pair[1] @@ -139,11 +137,6 @@ func (pair *testPair) Send( p0, p1 = p1, p0 } - transportJunk := 0 - if len(optTransportJunk) > 0 { - transportJunk = optTransportJunk[0] - } - msg := tuntest.Ping(p0.ip, p1.ip) p1.tun.Outbound <- msg timer := time.NewTimer(6 * time.Second) @@ -151,10 +144,7 @@ func (pair *testPair) Send( var err error select { case msgRecv := <-p0.tun.Inbound: - fmt.Printf("%x\n", msg) - fmt.Printf("%x\n", msgRecv[transportJunk:]) - fmt.Printf("%x\n", msgRecv) - if !bytes.Equal(msg, msgRecv[transportJunk:]) { + if !bytes.Equal(msg, msgRecv) { err = fmt.Errorf("%s did not transit correctly", ping) } case <-timer.C: @@ -238,26 +228,24 @@ func TestTwoDevicePing(t *testing.T) { func TestAWGDevicePing(t *testing.T) { goroutineLeakCheck(t) - transportJunk := 5 - pair := genTestPair(t, true, "jc", "5", "jmin", "500", "jmax", "1000", - // "s1", "30", - // "s2", "40", + "s1", "30", + "s2", "40", "s3", "50", - "s4", strconv.Itoa(transportJunk), - // "h1", "123456", - // "h2", "67543", - // "h3", "123123", - // "h4", "32345", + "s4", "5", + "h1", "123456", + "h2", "67543", + "h3", "123123", + "h4", "32345", ) t.Run("ping 1.0.0.1", func(t *testing.T) { - pair.Send(t, Ping, nil, transportJunk) + pair.Send(t, Ping, nil) }) t.Run("ping 1.0.0.2", func(t *testing.T) { - pair.Send(t, Pong, nil, transportJunk) + pair.Send(t, Pong, nil) }) } diff --git a/device/receive.go b/device/receive.go index 0844b79..1aaba3a 100644 --- a/device/receive.go +++ b/device/receive.go @@ -165,11 +165,17 @@ func (device *Device) RoutineReceiveIncoming( continue } - packet = packet[transportJunkSize:] + // remove junk from bufsArrs by shifting the packet + // this buffer is also used for decryption, so it needs to be corrected + copy(bufsArrs[i][:size], packet[transportJunkSize:]) + size -= transportJunkSize + // need to reinitialize packet as well + packet = packet[:size] } } else { msgType = binary.LittleEndian.Uint32(packet[:4]) } + switch msgType { // check if transport diff --git a/device/send.go b/device/send.go index 234c391..64c5ab3 100644 --- a/device/send.go +++ b/device/send.go @@ -584,6 +584,7 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { for _, elem := range elemsContainer.elems { if len(elem.packet) != MessageKeepaliveSize { dataSent = true + junkedHeader, err := device.awg.CreateTransportHeaderJunk() if err != nil { device.log.Errorf("%v - %v", device, err)