From 5e03df9fbd2ef3b6ab407f8ef99f66e4efb3e9ef Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Thu, 3 Jul 2025 06:38:00 +0200
Subject: [PATCH] feat: complete s4 logic
---
device/device_test.go | 32 ++++++++++----------------------
device/receive.go | 8 +++++++-
device/send.go | 1 +
3 files changed, 18 insertions(+), 23 deletions(-)
diff --git a/device/device_test.go b/device/device_test.go
index f9eb50c..6eb52e9 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -17,7 +17,6 @@ import (
"os/signal"
"runtime"
"runtime/pprof"
- "strconv"
"sync"
"testing"
"time"
@@ -130,7 +129,6 @@ func (pair *testPair) Send(
tb testing.TB,
ping SendDirection,
done chan struct{},
- optTransportJunk ...int,
) {
tb.Helper()
p0, p1 := pair[0], pair[1]
@@ -139,11 +137,6 @@ func (pair *testPair) Send(
p0, p1 = p1, p0
}
- transportJunk := 0
- if len(optTransportJunk) > 0 {
- transportJunk = optTransportJunk[0]
- }
-
msg := tuntest.Ping(p0.ip, p1.ip)
p1.tun.Outbound <- msg
timer := time.NewTimer(6 * time.Second)
@@ -151,10 +144,7 @@ func (pair *testPair) Send(
var err error
select {
case msgRecv := <-p0.tun.Inbound:
- fmt.Printf("%x\n", msg)
- fmt.Printf("%x\n", msgRecv[transportJunk:])
- fmt.Printf("%x\n", msgRecv)
- if !bytes.Equal(msg, msgRecv[transportJunk:]) {
+ if !bytes.Equal(msg, msgRecv) {
err = fmt.Errorf("%s did not transit correctly", ping)
}
case <-timer.C:
@@ -238,26 +228,24 @@ func TestTwoDevicePing(t *testing.T) {
func TestAWGDevicePing(t *testing.T) {
goroutineLeakCheck(t)
- transportJunk := 5
-
pair := genTestPair(t, true,
"jc", "5",
"jmin", "500",
"jmax", "1000",
- // "s1", "30",
- // "s2", "40",
+ "s1", "30",
+ "s2", "40",
"s3", "50",
- "s4", strconv.Itoa(transportJunk),
- // "h1", "123456",
- // "h2", "67543",
- // "h3", "123123",
- // "h4", "32345",
+ "s4", "5",
+ "h1", "123456",
+ "h2", "67543",
+ "h3", "123123",
+ "h4", "32345",
)
t.Run("ping 1.0.0.1", func(t *testing.T) {
- pair.Send(t, Ping, nil, transportJunk)
+ pair.Send(t, Ping, nil)
})
t.Run("ping 1.0.0.2", func(t *testing.T) {
- pair.Send(t, Pong, nil, transportJunk)
+ pair.Send(t, Pong, nil)
})
}
diff --git a/device/receive.go b/device/receive.go
index 0844b79..1aaba3a 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -165,11 +165,17 @@ func (device *Device) RoutineReceiveIncoming(
continue
}
- packet = packet[transportJunkSize:]
+ // remove junk from bufsArrs by shifting the packet
+ // this buffer is also used for decryption, so it needs to be corrected
+ copy(bufsArrs[i][:size], packet[transportJunkSize:])
+ size -= transportJunkSize
+ // need to reinitialize packet as well
+ packet = packet[:size]
}
} else {
msgType = binary.LittleEndian.Uint32(packet[:4])
}
+
switch msgType {
// check if transport
diff --git a/device/send.go b/device/send.go
index 234c391..64c5ab3 100644
--- a/device/send.go
+++ b/device/send.go
@@ -584,6 +584,7 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
for _, elem := range elemsContainer.elems {
if len(elem.packet) != MessageKeepaliveSize {
dataSent = true
+
junkedHeader, err := device.awg.CreateTransportHeaderJunk()
if err != nil {
device.log.Errorf("%v - %v", device, err)