From e8dc69d407b826798bd8cb9d99050e3d122ed6df Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Thu, 12 Jun 2025 06:02:45 +0200
Subject: [PATCH] fix: packet counter; test special handshake
---
device/awg/special_handshake_handler.go | 16 +++++++++---
device/awg/tag_generator.go | 33 ++++++++++++++++++------
device/awg/tag_junk_generator_handler.go | 3 +++
device/awg/tag_parser.go | 12 ++++-----
device/device_test.go | 17 ++++++++----
device/peer.go | 14 +++++++---
device/receive.go | 6 ++++-
device/send.go | 12 ++++++---
device/uapi.go | 4 +--
9 files changed, 84 insertions(+), 33 deletions(-)
diff --git a/device/awg/special_handshake_handler.go b/device/awg/special_handshake_handler.go
index 6ba29e8..fb83a53 100644
--- a/device/awg/special_handshake_handler.go
+++ b/device/awg/special_handshake_handler.go
@@ -4,12 +4,22 @@ import (
"errors"
"time"
+ "github.com/tevino/abool"
"go.uber.org/atomic"
)
-// TODO: atomic/ and better way to use this
+// TODO: atomic?/ and better way to use this
var PacketCounter *atomic.Uint64 = atomic.NewUint64(0)
+// TODO
+var WaitResponse = struct {
+ Channel chan struct{}
+ ShouldWait *abool.AtomicBool
+}{
+ make(chan struct{}, 1),
+ abool.New(),
+}
+
type SpecialHandshakeHandler struct {
isFirstDone bool
SpecialJunk TagJunkGeneratorHandler
@@ -39,7 +49,7 @@ func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte {
// TODO: create tests
if !handler.isFirstDone {
handler.isFirstDone = true
- handler.nextItime = time.Now().Add(time.Duration(handler.ITimeout))
+ handler.nextItime = time.Now().Add(handler.ITimeout)
return nil
}
@@ -48,7 +58,7 @@ func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte {
}
rv := handler.SpecialJunk.GeneratePackets()
- handler.nextItime = time.Now().Add(time.Duration(handler.ITimeout))
+ handler.nextItime = time.Now().Add(handler.ITimeout)
return rv
}
diff --git a/device/awg/tag_generator.go b/device/awg/tag_generator.go
index 6e18384..4aaa727 100644
--- a/device/awg/tag_generator.go
+++ b/device/awg/tag_generator.go
@@ -99,7 +99,6 @@ type TimestampGenerator struct {
func (tg *TimestampGenerator) Generate() []byte {
buf := make([]byte, 8)
binary.BigEndian.PutUint64(buf, uint64(time.Now().Unix()))
- fmt.Printf("timestamp: %v\n", buf)
return buf
}
@@ -120,7 +119,6 @@ type WaitTimeoutGenerator struct {
}
func (wtg *WaitTimeoutGenerator) Generate() []byte {
- fmt.Printf("sleep: %d\n", wtg.waitTimeout.Milliseconds())
time.Sleep(wtg.waitTimeout)
return []byte{}
}
@@ -130,27 +128,25 @@ func (wtg *WaitTimeoutGenerator) Size() int {
}
func newWaitTimeoutGenerator(param string) (Generator, error) {
- t, err := strconv.Atoi(param)
+ timeout, err := strconv.Atoi(param)
if err != nil {
return nil, fmt.Errorf("timeout parse int: %w", err)
}
- if t > 5000 {
+ if timeout > 5000 {
return nil, fmt.Errorf("timeout must be less than 5000ms")
}
- return &WaitTimeoutGenerator{waitTimeout: time.Duration(t) * time.Millisecond}, nil
+ return &WaitTimeoutGenerator{waitTimeout: time.Duration(timeout) * time.Millisecond}, nil
}
type PacketCounterGenerator struct {
- // counter *atomic.Uint64
}
func (c *PacketCounterGenerator) Generate() []byte {
buf := make([]byte, 8)
// TODO: better way to handle counter tag
binary.BigEndian.PutUint64(buf, PacketCounter.Load())
- fmt.Printf("packet %d; counter: %v\n", PacketCounter.Load(), buf)
return buf
}
@@ -163,6 +159,27 @@ func newPacketCounterGenerator(param string) (Generator, error) {
return nil, fmt.Errorf("packet counter param needs to be empty: %s", param)
}
- // return &PacketCounterGenerator{counter: atomic.NewUint64(0)}, nil
return &PacketCounterGenerator{}, nil
}
+
+type WaitResponseGenerator struct {
+}
+
+func (c *WaitResponseGenerator) Generate() []byte {
+ WaitResponse.ShouldWait.Set()
+ <-WaitResponse.Channel
+ WaitResponse.ShouldWait.UnSet()
+ return []byte{}
+}
+
+func (c *WaitResponseGenerator) Size() int {
+ return 0
+}
+
+func newWaitResponseGenerator(param string) (Generator, error) {
+ if len(param) != 0 {
+ return nil, fmt.Errorf("wait response param needs to be empty: %s", param)
+ }
+
+ return &WaitResponseGenerator{}, nil
+}
diff --git a/device/awg/tag_junk_generator_handler.go b/device/awg/tag_junk_generator_handler.go
index dfa53cf..934dadf 100644
--- a/device/awg/tag_junk_generator_handler.go
+++ b/device/awg/tag_junk_generator_handler.go
@@ -43,10 +43,13 @@ func (handler *TagJunkGeneratorHandler) Validate() error {
func (handler *TagJunkGeneratorHandler) GeneratePackets() [][]byte {
var rv = make([][]byte, 0, handler.length+handler.DefaultJunkCount)
+
for i, tagGenerator := range handler.tagGenerators {
+ PacketCounter.Inc()
rv = append(rv, make([]byte, tagGenerator.packetSize))
copy(rv[i], tagGenerator.generatePacket())
}
+ PacketCounter.Add(uint64(handler.DefaultJunkCount))
return rv
}
diff --git a/device/awg/tag_parser.go b/device/awg/tag_parser.go
index d6dba1b..1180ac6 100644
--- a/device/awg/tag_parser.go
+++ b/device/awg/tag_parser.go
@@ -19,12 +19,12 @@ const (
)
var generatorCreator = map[EnumTag]newGenerator{
- BytesEnumTag: newBytesGenerator,
- CounterEnumTag: newPacketCounterGenerator,
- TimestampEnumTag: newTimestampGenerator,
- RandomBytesEnumTag: newRandomPacketGenerator,
- WaitTimeoutEnumTag: newWaitTimeoutGenerator,
- WaitResponseEnumTag: func(s string) (Generator, error) { return &BytesGenerator{}, nil },
+ BytesEnumTag: newBytesGenerator,
+ CounterEnumTag: newPacketCounterGenerator,
+ TimestampEnumTag: newTimestampGenerator,
+ RandomBytesEnumTag: newRandomPacketGenerator,
+ WaitTimeoutEnumTag: newWaitTimeoutGenerator,
+ // WaitResponseEnumTag: newWaitResponseGenerator,
}
// helper map to determine enumTags are unique
diff --git a/device/device_test.go b/device/device_test.go
index b57f3e6..a72159a 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -176,7 +176,7 @@ func (pair *testPair) Send(
}
msg := tuntest.Ping(p0.ip, p1.ip)
p1.tun.Outbound <- msg
- timer := time.NewTimer(5 * time.Second)
+ timer := time.NewTimer(6 * time.Second)
defer timer.Stop()
var err error
select {
@@ -289,11 +289,12 @@ func TestAWGDevicePing(t *testing.T) {
func TestAWGHandshakeDevicePing(t *testing.T) {
goroutineLeakCheck(t)
pair := genTestPair(t, true,
- // "i1", "",
- // "i2", "",
+ "i1", "",
+ "i2", "",
"j1", "",
"j2", "",
"j3", "",
+ "itime", "60",
// "jc", "1",
// "jmin", "500",
// "jmax", "1000",
@@ -305,10 +306,16 @@ func TestAWGHandshakeDevicePing(t *testing.T) {
// "h3", "123123",
)
t.Run("ping 1.0.0.1", func(t *testing.T) {
- pair.Send(t, Ping, nil)
+ for {
+ pair.Send(t, Ping, nil)
+ time.Sleep(2 * time.Second)
+ }
})
t.Run("ping 1.0.0.2", func(t *testing.T) {
- pair.Send(t, Pong, nil)
+ for {
+ pair.Send(t, Pong, nil)
+ time.Sleep(2 * time.Second)
+ }
})
}
diff --git a/device/peer.go b/device/peer.go
index 7705b6f..9c7eaab 100644
--- a/device/peer.go
+++ b/device/peer.go
@@ -138,10 +138,6 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error {
if err == nil {
var totalLen uint64
for _, b := range buffers {
- // TODO
- awg.PacketCounter.Inc()
- peer.device.log.Verbosef("%v - Sending %d bytes to %s; pc: %d", peer, len(b), endpoint)
-
totalLen += uint64(len(b))
}
peer.txBytes.Add(totalLen)
@@ -149,6 +145,16 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error {
return err
}
+func (peer *Peer) SendBuffersCountPacket(buffers [][]byte) error {
+ err := peer.SendBuffers(buffers)
+ if err == nil {
+ awg.PacketCounter.Add(uint64(len(buffers)))
+ return nil
+ }
+
+ return err
+}
+
func (peer *Peer) String() string {
// The awful goo that follows is identical to:
//
diff --git a/device/receive.go b/device/receive.go
index 1875c76..6eaa746 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -137,10 +137,14 @@ func (device *Device) RoutineReceiveIncoming(
}
// check size of packet
-
packet := bufsArrs[i][:size]
var msgType uint32
if device.isAdvancedSecurityOn() {
+ // TODO:
+ // if awg.WaitResponse.ShouldWait.IsSet() {
+ // awg.WaitResponse.Channel <- struct{}{}
+ // }
+
if assumedMsgType, ok := packetSizeToMsgType[size]; ok {
junkSize := msgTypeToJunkSize[assumedMsgType]
// transport size can align with other header types;
diff --git a/device/send.go b/device/send.go
index f87d4ae..74015c2 100644
--- a/device/send.go
+++ b/device/send.go
@@ -134,8 +134,12 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
// set junks depending on packet type
junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk()
if junks == nil {
- peer.device.log.Verbosef("%v - No special junks defined, using controlled", peer)
junks = peer.device.awg.HandshakeHandler.GenerateControlledJunk()
+ if junks != nil {
+ peer.device.log.Verbosef("%v - Controlled junks sent", peer)
+ }
+ } else {
+ peer.device.log.Verbosef("%v - Special junks sent", peer)
}
peer.device.awg.ASecMux.RUnlock()
} else {
@@ -186,7 +190,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
sendBuffer = append(sendBuffer, junkedHeader)
- err = peer.SendBuffers(sendBuffer)
+ err = peer.SendBuffersCountPacket(sendBuffer)
if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
}
@@ -242,7 +246,7 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.timersAnyAuthenticatedPacketSent()
// TODO: allocation could be avoided
- err = peer.SendBuffers([][]byte{junkedHeader})
+ err = peer.SendBuffersCountPacket([][]byte{junkedHeader})
if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
}
@@ -596,7 +600,7 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
- err := peer.SendBuffers(bufs)
+ err := peer.SendBuffersCountPacket(bufs)
if dataSent {
peer.timersDataSent()
}
diff --git a/device/uapi.go b/device/uapi.go
index 216f7cd..1e95e48 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -394,9 +394,9 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse itime %w", err)
}
- device.log.Verbosef("UAPI: Updating itime %s", itime)
+ device.log.Verbosef("UAPI: Updating itime %d", itime)
- tempAwg.HandshakeHandler.ITimeout = time.Duration(itime)
+ tempAwg.HandshakeHandler.ITimeout = time.Duration(itime) * time.Second
tempAwg.HandshakeHandler.IsSet = true
default:
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)