feat: complete s4 logic

This commit is contained in:
Mark Puha 2025-07-03 06:38:00 +02:00
parent 05fbf0feb0
commit 5e03df9fbd
3 changed files with 18 additions and 23 deletions

View file

@ -17,7 +17,6 @@ import (
"os/signal" "os/signal"
"runtime" "runtime"
"runtime/pprof" "runtime/pprof"
"strconv"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -130,7 +129,6 @@ func (pair *testPair) Send(
tb testing.TB, tb testing.TB,
ping SendDirection, ping SendDirection,
done chan struct{}, done chan struct{},
optTransportJunk ...int,
) { ) {
tb.Helper() tb.Helper()
p0, p1 := pair[0], pair[1] p0, p1 := pair[0], pair[1]
@ -139,11 +137,6 @@ func (pair *testPair) Send(
p0, p1 = p1, p0 p0, p1 = p1, p0
} }
transportJunk := 0
if len(optTransportJunk) > 0 {
transportJunk = optTransportJunk[0]
}
msg := tuntest.Ping(p0.ip, p1.ip) msg := tuntest.Ping(p0.ip, p1.ip)
p1.tun.Outbound <- msg p1.tun.Outbound <- msg
timer := time.NewTimer(6 * time.Second) timer := time.NewTimer(6 * time.Second)
@ -151,10 +144,7 @@ func (pair *testPair) Send(
var err error var err error
select { select {
case msgRecv := <-p0.tun.Inbound: case msgRecv := <-p0.tun.Inbound:
fmt.Printf("%x\n", msg) if !bytes.Equal(msg, msgRecv) {
fmt.Printf("%x\n", msgRecv[transportJunk:])
fmt.Printf("%x\n", msgRecv)
if !bytes.Equal(msg, msgRecv[transportJunk:]) {
err = fmt.Errorf("%s did not transit correctly", ping) err = fmt.Errorf("%s did not transit correctly", ping)
} }
case <-timer.C: case <-timer.C:
@ -238,26 +228,24 @@ func TestTwoDevicePing(t *testing.T) {
func TestAWGDevicePing(t *testing.T) { func TestAWGDevicePing(t *testing.T) {
goroutineLeakCheck(t) goroutineLeakCheck(t)
transportJunk := 5
pair := genTestPair(t, true, pair := genTestPair(t, true,
"jc", "5", "jc", "5",
"jmin", "500", "jmin", "500",
"jmax", "1000", "jmax", "1000",
// "s1", "30", "s1", "30",
// "s2", "40", "s2", "40",
"s3", "50", "s3", "50",
"s4", strconv.Itoa(transportJunk), "s4", "5",
// "h1", "123456", "h1", "123456",
// "h2", "67543", "h2", "67543",
// "h3", "123123", "h3", "123123",
// "h4", "32345", "h4", "32345",
) )
t.Run("ping 1.0.0.1", func(t *testing.T) { t.Run("ping 1.0.0.1", func(t *testing.T) {
pair.Send(t, Ping, nil, transportJunk) pair.Send(t, Ping, nil)
}) })
t.Run("ping 1.0.0.2", func(t *testing.T) { t.Run("ping 1.0.0.2", func(t *testing.T) {
pair.Send(t, Pong, nil, transportJunk) pair.Send(t, Pong, nil)
}) })
} }

View file

@ -165,11 +165,17 @@ func (device *Device) RoutineReceiveIncoming(
continue continue
} }
packet = packet[transportJunkSize:] // remove junk from bufsArrs by shifting the packet
// this buffer is also used for decryption, so it needs to be corrected
copy(bufsArrs[i][:size], packet[transportJunkSize:])
size -= transportJunkSize
// need to reinitialize packet as well
packet = packet[:size]
} }
} else { } else {
msgType = binary.LittleEndian.Uint32(packet[:4]) msgType = binary.LittleEndian.Uint32(packet[:4])
} }
switch msgType { switch msgType {
// check if transport // check if transport

View file

@ -584,6 +584,7 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
for _, elem := range elemsContainer.elems { for _, elem := range elemsContainer.elems {
if len(elem.packet) != MessageKeepaliveSize { if len(elem.packet) != MessageKeepaliveSize {
dataSent = true dataSent = true
junkedHeader, err := device.awg.CreateTransportHeaderJunk() junkedHeader, err := device.awg.CreateTransportHeaderJunk()
if err != nil { if err != nil {
device.log.Errorf("%v - %v", device, err) device.log.Errorf("%v - %v", device, err)