From f6c385f6a7ecb1a44dac1d88e473c68d6f9c41c1 Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Wed, 11 Jun 2025 20:12:36 +0200
Subject: [PATCH] feat: test
---
device/awg/junk_creator.go | 6 +-
device/awg/junk_creator_test.go | 2 +-
device/awg/special_handshake_handler.go | 28 ++++--
device/awg/tag_generator.go | 52 ++++++----
device/awg/tag_generator_test.go | 62 +++++++++++-
device/awg/tag_junk_generator.go | 4 +-
device/awg/tag_junk_generator_handler.go | 27 +++---
device/awg/tag_junk_generator_handler_test.go | 35 ++++---
device/awg/tag_parser.go | 11 ++-
device/device_test.go | 94 ++++++++++++-------
device/peer.go | 6 +-
device/send.go | 7 +-
go.mod | 1 +
go.sum | 2 +
14 files changed, 240 insertions(+), 97 deletions(-)
diff --git a/device/awg/junk_creator.go b/device/awg/junk_creator.go
index d67eb00..08294b1 100644
--- a/device/awg/junk_creator.go
+++ b/device/awg/junk_creator.go
@@ -22,18 +22,18 @@ func NewJunkCreator(aSecCfg aSecCfgType) (junkCreator, error) {
}
// Should be called with aSecMux RLocked
-func (jc *junkCreator) CreateJunkPackets(junks [][]byte) error {
+func (jc *junkCreator) CreateJunkPackets(junks *[][]byte) error {
if jc.aSecCfg.JunkPacketCount == 0 {
return nil
}
- for i := range jc.aSecCfg.JunkPacketCount {
+ for range jc.aSecCfg.JunkPacketCount {
packetSize := jc.randomPacketSize()
junk, err := jc.randomJunkWithSize(packetSize)
if err != nil {
return fmt.Errorf("create junk packet: %v", err)
}
- junks[i] = junk
+ *junks = append(*junks, junk)
}
return nil
}
diff --git a/device/awg/junk_creator_test.go b/device/awg/junk_creator_test.go
index 7ac6cad..ce1aa65 100644
--- a/device/awg/junk_creator_test.go
+++ b/device/awg/junk_creator_test.go
@@ -35,7 +35,7 @@ func Test_junkCreator_createJunkPackets(t *testing.T) {
}
t.Run("valid", func(t *testing.T) {
got := make([][]byte, jc.aSecCfg.JunkPacketCount)
- err := jc.CreateJunkPackets(got)
+ err := jc.CreateJunkPackets(&got)
if err != nil {
t.Errorf(
"junkCreator.createJunkPackets() = %v; failed",
diff --git a/device/awg/special_handshake_handler.go b/device/awg/special_handshake_handler.go
index 17f2a53..6ba29e8 100644
--- a/device/awg/special_handshake_handler.go
+++ b/device/awg/special_handshake_handler.go
@@ -3,18 +3,22 @@ package awg
import (
"errors"
"time"
+
+ "go.uber.org/atomic"
)
+// TODO: atomic/ and better way to use this
+var PacketCounter *atomic.Uint64 = atomic.NewUint64(0)
+
type SpecialHandshakeHandler struct {
+ isFirstDone bool
SpecialJunk TagJunkGeneratorHandler
ControlledJunk TagJunkGeneratorHandler
nextItime time.Time
ITimeout time.Duration // seconds
- // TODO: maybe atomic?
- PacketCounter uint64
- IsSet bool
+ IsSet bool
}
func (handler *SpecialHandshakeHandler) Validate() error {
@@ -29,13 +33,21 @@ func (handler *SpecialHandshakeHandler) Validate() error {
}
func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte {
- // TODO: distiungish between first and the rest of the packets
+ if !handler.SpecialJunk.IsDefined() {
+ return nil
+ }
+ // TODO: create tests
+ if !handler.isFirstDone {
+ handler.isFirstDone = true
+ handler.nextItime = time.Now().Add(time.Duration(handler.ITimeout))
+ return nil
+ }
+
if !handler.isTimeToSendSpecial() {
return nil
}
rv := handler.SpecialJunk.GeneratePackets()
-
handler.nextItime = time.Now().Add(time.Duration(handler.ITimeout))
return rv
@@ -45,6 +57,10 @@ func (handler *SpecialHandshakeHandler) isTimeToSendSpecial() bool {
return time.Now().After(handler.nextItime)
}
-func (handler *SpecialHandshakeHandler) PrepareControlledJunk() [][]byte {
+func (handler *SpecialHandshakeHandler) GenerateControlledJunk() [][]byte {
+ if !handler.ControlledJunk.IsDefined() {
+ return nil
+ }
+
return handler.ControlledJunk.GeneratePackets()
}
diff --git a/device/awg/tag_generator.go b/device/awg/tag_generator.go
index 1974a6e..6e18384 100644
--- a/device/awg/tag_generator.go
+++ b/device/awg/tag_generator.go
@@ -10,6 +10,7 @@ import (
"time"
v2 "math/rand/v2"
+ // "go.uber.org/atomic"
)
type Generator interface {
@@ -33,9 +34,8 @@ func (bg *BytesGenerator) Size() int {
}
func newBytesGenerator(param string) (Generator, error) {
- isNotHex := !strings.HasPrefix(param, "0x") ||
- !strings.HasPrefix(param, "0x") && !isHexString(param)
- if isNotHex {
+ hasPrefix := strings.HasPrefix(param, "0x") || strings.HasPrefix(param, "0X")
+ if !hasPrefix {
return nil, fmt.Errorf("not correct hex: %s", param)
}
@@ -47,17 +47,6 @@ func newBytesGenerator(param string) (Generator, error) {
return &BytesGenerator{value: hex, size: len(hex)}, nil
}
-func isHexString(s string) bool {
- for _, char := range s {
- if !((char >= '0' && char <= '9') ||
- (char >= 'a' && char <= 'f') ||
- (char >= 'A' && char <= 'F')) {
- return false
- }
- }
- return len(s) > 0
-}
-
func hexToBytes(hexStr string) ([]byte, error) {
hexStr = strings.TrimPrefix(hexStr, "0x")
hexStr = strings.TrimPrefix(hexStr, "0X")
@@ -110,6 +99,7 @@ type TimestampGenerator struct {
func (tg *TimestampGenerator) Generate() []byte {
buf := make([]byte, 8)
binary.BigEndian.PutUint64(buf, uint64(time.Now().Unix()))
+ fmt.Printf("timestamp: %v\n", buf)
return buf
}
@@ -130,6 +120,7 @@ type WaitTimeoutGenerator struct {
}
func (wtg *WaitTimeoutGenerator) Generate() []byte {
+ fmt.Printf("sleep: %d\n", wtg.waitTimeout.Milliseconds())
time.Sleep(wtg.waitTimeout)
return []byte{}
}
@@ -139,14 +130,39 @@ func (wtg *WaitTimeoutGenerator) Size() int {
}
func newWaitTimeoutGenerator(param string) (Generator, error) {
- size, err := strconv.Atoi(param)
+ t, err := strconv.Atoi(param)
if err != nil {
return nil, fmt.Errorf("timeout parse int: %w", err)
}
- if size > 5000 {
- return nil, fmt.Errorf("timeout size must be less than 5000ms")
+ if t > 5000 {
+ return nil, fmt.Errorf("timeout must be less than 5000ms")
}
- return &WaitTimeoutGenerator{}, nil
+ return &WaitTimeoutGenerator{waitTimeout: time.Duration(t) * time.Millisecond}, nil
+}
+
+type PacketCounterGenerator struct {
+ // counter *atomic.Uint64
+}
+
+func (c *PacketCounterGenerator) Generate() []byte {
+ buf := make([]byte, 8)
+ // TODO: better way to handle counter tag
+ binary.BigEndian.PutUint64(buf, PacketCounter.Load())
+ fmt.Printf("packet %d; counter: %v\n", PacketCounter.Load(), buf)
+ return buf
+}
+
+func (c *PacketCounterGenerator) Size() int {
+ return 8
+}
+
+func newPacketCounterGenerator(param string) (Generator, error) {
+ if len(param) != 0 {
+ return nil, fmt.Errorf("packet counter param needs to be empty: %s", param)
+ }
+
+ // return &PacketCounterGenerator{counter: atomic.NewUint64(0)}, nil
+ return &PacketCounterGenerator{}, nil
}
diff --git a/device/awg/tag_generator_test.go b/device/awg/tag_generator_test.go
index a5fc334..4950b33 100644
--- a/device/awg/tag_generator_test.go
+++ b/device/awg/tag_generator_test.go
@@ -1,6 +1,7 @@
package awg
import (
+ "encoding/binary"
"fmt"
"testing"
@@ -32,7 +33,14 @@ func Test_newBytesGenerator(t *testing.T) {
wantErr: fmt.Errorf("not correct hex"),
},
{
- name: "not only hex value",
+ name: "not only hex value with X",
+ args: args{
+ param: "0X12345q",
+ },
+ wantErr: fmt.Errorf("not correct hex"),
+ },
+ {
+ name: "not only hex value with x",
args: args{
param: "0x12345q",
},
@@ -127,3 +135,55 @@ func Test_newRandomPacketGenerator(t *testing.T) {
})
}
}
+
+func TestPacketCounterGenerator(t *testing.T) {
+ tests := []struct {
+ name string
+ param string
+ wantErr bool
+ }{
+ {
+ name: "Valid empty param",
+ param: "",
+ wantErr: false,
+ },
+ {
+ name: "Invalid non-empty param",
+ param: "anything",
+ wantErr: true,
+ },
+ }
+
+ for _, tc := range tests {
+ tc := tc // capture range variable
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ gen, err := newPacketCounterGenerator(tc.param)
+ if tc.wantErr {
+ require.Error(t, err)
+ return
+ }
+
+ require.NoError(t, err)
+ require.Equal(t, 8, gen.Size())
+
+ // Reset counter to known value for test
+ initialCount := uint64(42)
+ PacketCounter.Store(initialCount)
+
+ output := gen.Generate()
+ require.Equal(t, 8, len(output))
+
+ // Verify counter value in output
+ counterValue := binary.BigEndian.Uint64(output)
+ require.Equal(t, initialCount, counterValue)
+
+ // Increment counter and verify change
+ PacketCounter.Add(1)
+ output = gen.Generate()
+ counterValue = binary.BigEndian.Uint64(output)
+ require.Equal(t, initialCount+1, counterValue)
+ })
+ }
+}
diff --git a/device/awg/tag_junk_generator.go b/device/awg/tag_junk_generator.go
index a8b3cde..3d87a46 100644
--- a/device/awg/tag_junk_generator.go
+++ b/device/awg/tag_junk_generator.go
@@ -12,7 +12,7 @@ type TagJunkGenerator struct {
}
func newTagJunkGenerator(name string, size int) TagJunkGenerator {
- return TagJunkGenerator{name: name, generators: make([]Generator, size)}
+ return TagJunkGenerator{name: name, generators: make([]Generator, 0, size)}
}
func (tg *TagJunkGenerator) append(generator Generator) {
@@ -40,7 +40,7 @@ func (tg *TagJunkGenerator) nameIndex() (int, error) {
index, err := strconv.Atoi(tg.name[1:2])
if err != nil {
- return 0, fmt.Errorf("name should be 2 char long: %w", err)
+ return 0, fmt.Errorf("name 2 char should be an int %w", err)
}
return index, nil
}
diff --git a/device/awg/tag_junk_generator_handler.go b/device/awg/tag_junk_generator_handler.go
index b8cf9e3..dfa53cf 100644
--- a/device/awg/tag_junk_generator_handler.go
+++ b/device/awg/tag_junk_generator_handler.go
@@ -3,23 +3,26 @@ package awg
import "fmt"
type TagJunkGeneratorHandler struct {
- generators []TagJunkGenerator
- length int
- // Jc
- DefaultJunkCount int
+ tagGenerators []TagJunkGenerator
+ length int
+ DefaultJunkCount int // Jc
}
func (handler *TagJunkGeneratorHandler) AppendGenerator(generators TagJunkGenerator) {
- handler.generators = append(handler.generators, generators)
+ handler.tagGenerators = append(handler.tagGenerators, generators)
handler.length++
}
+func (handler *TagJunkGeneratorHandler) IsDefined() bool {
+ return len(handler.tagGenerators) > 0
+}
+
// validate that packets were defined consecutively
func (handler *TagJunkGeneratorHandler) Validate() error {
- seen := make([]bool, len(handler.generators))
- for _, generator := range handler.generators {
+ seen := make([]bool, len(handler.tagGenerators))
+ for _, generator := range handler.tagGenerators {
index, err := generator.nameIndex()
- if index > len(handler.generators) {
+ if index > len(handler.tagGenerators) {
return fmt.Errorf("junk packet index should be consecutive")
}
if err != nil {
@@ -39,10 +42,10 @@ func (handler *TagJunkGeneratorHandler) Validate() error {
}
func (handler *TagJunkGeneratorHandler) GeneratePackets() [][]byte {
- var rv = make([][]byte, handler.length+handler.DefaultJunkCount)
- for i, generator := range handler.generators {
- rv[i] = make([]byte, generator.packetSize)
- copy(rv[i], generator.generatePacket())
+ var rv = make([][]byte, 0, handler.length+handler.DefaultJunkCount)
+ for i, tagGenerator := range handler.tagGenerators {
+ rv = append(rv, make([]byte, tagGenerator.packetSize))
+ copy(rv[i], tagGenerator.generatePacket())
}
return rv
diff --git a/device/awg/tag_junk_generator_handler_test.go b/device/awg/tag_junk_generator_handler_test.go
index 2c36cda..3c5efc9 100644
--- a/device/awg/tag_junk_generator_handler_test.go
+++ b/device/awg/tag_junk_generator_handler_test.go
@@ -8,8 +8,6 @@ import (
)
func TestTagJunkGeneratorHandlerAppendGenerator(t *testing.T) {
- t.Parallel()
-
tests := []struct {
name string
generator TagJunkGenerator
@@ -28,20 +26,18 @@ func TestTagJunkGeneratorHandlerAppendGenerator(t *testing.T) {
// Initial length should be 0
require.Equal(t, 0, handler.length)
- require.Empty(t, handler.generators)
+ require.Empty(t, handler.tagGenerators)
// After append, length should be 1 and generator should be added
handler.AppendGenerator(tt.generator)
require.Equal(t, 1, handler.length)
- require.Len(t, handler.generators, 1)
- require.Equal(t, tt.generator, handler.generators[0])
+ require.Len(t, handler.tagGenerators, 1)
+ require.Equal(t, tt.generator, handler.tagGenerators[0])
})
}
}
func TestTagJunkGeneratorHandlerValidate(t *testing.T) {
- t.Parallel()
-
tests := []struct {
name string
generators []TagJunkGenerator
@@ -49,12 +45,13 @@ func TestTagJunkGeneratorHandlerValidate(t *testing.T) {
errMsg string
}{
{
- name: "valid consecutive indices",
+ name: "bad start",
generators: []TagJunkGenerator{
- newTagJunkGenerator("t1", 10),
- newTagJunkGenerator("t2", 10),
+ newTagJunkGenerator("t3", 10),
+ newTagJunkGenerator("t4", 10),
},
- wantErr: false,
+ wantErr: true,
+ errMsg: "junk packet index should be consecutive",
},
{
name: "non-consecutive indices",
@@ -65,6 +62,16 @@ func TestTagJunkGeneratorHandlerValidate(t *testing.T) {
wantErr: true,
errMsg: "junk packet index should be consecutive",
},
+ {
+ name: "consecutive indices",
+ generators: []TagJunkGenerator{
+ newTagJunkGenerator("t1", 10),
+ newTagJunkGenerator("t2", 10),
+ newTagJunkGenerator("t3", 10),
+ newTagJunkGenerator("t4", 10),
+ newTagJunkGenerator("t5", 10),
+ },
+ },
{
name: "nameIndex error",
generators: []TagJunkGenerator{
@@ -88,16 +95,14 @@ func TestTagJunkGeneratorHandlerValidate(t *testing.T) {
if tt.wantErr {
require.Error(t, err)
require.Contains(t, err.Error(), tt.errMsg)
- } else {
- require.NoError(t, err)
+ return
}
+ require.NoError(t, err)
})
}
}
func TestTagJunkGeneratorHandlerGenerate(t *testing.T) {
- t.Parallel()
-
mockByte1 := []byte{0x01, 0x02}
mockByte2 := []byte{0x03, 0x04, 0x05}
mockGen1 := internal.NewMockByteGenerator(mockByte1)
diff --git a/device/awg/tag_parser.go b/device/awg/tag_parser.go
index c3254c8..d6dba1b 100644
--- a/device/awg/tag_parser.go
+++ b/device/awg/tag_parser.go
@@ -20,7 +20,7 @@ const (
var generatorCreator = map[EnumTag]newGenerator{
BytesEnumTag: newBytesGenerator,
- CounterEnumTag: func(s string) (Generator, error) { return &BytesGenerator{}, nil },
+ CounterEnumTag: newPacketCounterGenerator,
TimestampEnumTag: newTimestampGenerator,
RandomBytesEnumTag: newRandomPacketGenerator,
WaitTimeoutEnumTag: newWaitTimeoutGenerator,
@@ -89,6 +89,15 @@ func Parse(name, input string) (TagJunkGenerator, error) {
return TagJunkGenerator{}, fmt.Errorf("gen: %w", err)
}
+ // TODO: handle counter tag
+ // if tag.Name == CounterEnumTag {
+ // packetCounter, ok := generator.(*PacketCounterGenerator)
+ // if !ok {
+ // log.Fatalf("packet counter generator expected, got %T", generator)
+ // }
+ // PacketCounter = packetCounter.counter
+ // }
+
rv.append(generator)
}
diff --git a/device/device_test.go b/device/device_test.go
index 7e735e1..b57f3e6 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -91,7 +91,7 @@ func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
return
}
-func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
+func genAWGConfigs(tb testing.TB, cfg ...string) (cfgs, endpointCfgs [2]string) {
var key1, key2 NoisePrivateKey
_, err := rand.Read(key1[:])
if err != nil {
@@ -103,46 +103,35 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
}
pub1, pub2 := key1.publicKey(), key2.publicKey()
- cfgs[0] = uapiCfg(
+ args0 := append([]string(nil), cfg...)
+ args0 = append(args0, []string{
"private_key", hex.EncodeToString(key1[:]),
"listen_port", "0",
"replace_peers", "true",
- "jc", "5",
- "jmin", "500",
- "jmax", "1000",
- "s1", "30",
- "s2", "40",
- "h1", "123456",
- "h2", "67543",
- "h4", "32345",
- "h3", "123123",
"public_key", hex.EncodeToString(pub2[:]),
"protocol_version", "1",
"replace_allowed_ips", "true",
"allowed_ip", "1.0.0.2/32",
- )
+ }...)
+ cfgs[0] = uapiCfg(args0...)
+
endpointCfgs[0] = uapiCfg(
"public_key", hex.EncodeToString(pub2[:]),
"endpoint", "127.0.0.1:%d",
)
- cfgs[1] = uapiCfg(
+
+ args1 := append([]string(nil), cfg...)
+ args1 = append(args1, []string{
"private_key", hex.EncodeToString(key2[:]),
"listen_port", "0",
"replace_peers", "true",
- "jc", "5",
- "jmin", "500",
- "jmax", "1000",
- "s1", "30",
- "s2", "40",
- "h1", "123456",
- "h2", "67543",
- "h4", "32345",
- "h3", "123123",
"public_key", hex.EncodeToString(pub1[:]),
"protocol_version", "1",
"replace_allowed_ips", "true",
"allowed_ip", "1.0.0.1/32",
- )
+ }...)
+
+ cfgs[1] = uapiCfg(args1...)
endpointCfgs[1] = uapiCfg(
"public_key", hex.EncodeToString(pub1[:]),
"endpoint", "127.0.0.1:%d",
@@ -214,11 +203,12 @@ func (pair *testPair) Send(
// genTestPair creates a testPair.
func genTestPair(
tb testing.TB,
- realSocket, withASecurity bool,
+ realSocket bool,
+ extraCfg ...string,
) (pair testPair) {
var cfg, endpointCfg [2]string
- if withASecurity {
- cfg, endpointCfg = genASecurityConfigs(tb)
+ if len(extraCfg) > 0 {
+ cfg, endpointCfg = genAWGConfigs(tb, extraCfg...)
} else {
cfg, endpointCfg = genConfigs(tb)
}
@@ -265,7 +255,7 @@ func genTestPair(
func TestTwoDevicePing(t *testing.T) {
goroutineLeakCheck(t)
- pair := genTestPair(t, true, false)
+ pair := genTestPair(t, true)
t.Run("ping 1.0.0.1", func(t *testing.T) {
pair.Send(t, Ping, nil)
})
@@ -275,9 +265,45 @@ func TestTwoDevicePing(t *testing.T) {
}
// Run test with -race=false to avoid the race for setting the default msgTypes 2 times
-func TestASecurityTwoDevicePing(t *testing.T) {
+func TestAWGDevicePing(t *testing.T) {
goroutineLeakCheck(t)
- pair := genTestPair(t, true, true)
+ pair := genTestPair(t, true,
+ "jc", "5",
+ "jmin", "500",
+ "jmax", "1000",
+ "s1", "30",
+ "s2", "40",
+ "h1", "123456",
+ "h2", "67543",
+ "h4", "32345",
+ "h3", "123123",
+ )
+ t.Run("ping 1.0.0.1", func(t *testing.T) {
+ pair.Send(t, Ping, nil)
+ })
+ t.Run("ping 1.0.0.2", func(t *testing.T) {
+ pair.Send(t, Pong, nil)
+ })
+}
+
+func TestAWGHandshakeDevicePing(t *testing.T) {
+ goroutineLeakCheck(t)
+ pair := genTestPair(t, true,
+ // "i1", "",
+ // "i2", "",
+ "j1", "",
+ "j2", "",
+ "j3", "",
+ // "jc", "1",
+ // "jmin", "500",
+ // "jmax", "1000",
+ // "s1", "30",
+ // "s2", "40",
+ // "h1", "123456",
+ // "h2", "67543",
+ // "h4", "32345",
+ // "h3", "123123",
+ )
t.Run("ping 1.0.0.1", func(t *testing.T) {
pair.Send(t, Ping, nil)
})
@@ -292,7 +318,7 @@ func TestUpDown(t *testing.T) {
const otrials = 10
for n := 0; n < otrials; n++ {
- pair := genTestPair(t, false, false)
+ pair := genTestPair(t, false)
for i := range pair {
for k := range pair[i].dev.peers.keyMap {
pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:])))
@@ -326,7 +352,7 @@ func TestUpDown(t *testing.T) {
// TestConcurrencySafety does other things concurrently with tunnel use.
// It is intended to be used with the race detector to catch data races.
func TestConcurrencySafety(t *testing.T) {
- pair := genTestPair(t, true, false)
+ pair := genTestPair(t, true)
done := make(chan struct{})
const warmupIters = 10
@@ -407,7 +433,7 @@ func TestConcurrencySafety(t *testing.T) {
}
func BenchmarkLatency(b *testing.B) {
- pair := genTestPair(b, true, false)
+ pair := genTestPair(b, true)
// Establish a connection.
pair.Send(b, Ping, nil)
@@ -421,7 +447,7 @@ func BenchmarkLatency(b *testing.B) {
}
func BenchmarkThroughput(b *testing.B) {
- pair := genTestPair(b, true, false)
+ pair := genTestPair(b, true)
// Establish a connection.
pair.Send(b, Ping, nil)
@@ -465,7 +491,7 @@ func BenchmarkThroughput(b *testing.B) {
}
func BenchmarkUAPIGet(b *testing.B) {
- pair := genTestPair(b, true, false)
+ pair := genTestPair(b, true)
pair.Send(b, Ping, nil)
pair.Send(b, Pong, nil)
b.ReportAllocs()
diff --git a/device/peer.go b/device/peer.go
index 9409ca6..7705b6f 100644
--- a/device/peer.go
+++ b/device/peer.go
@@ -13,6 +13,7 @@ import (
"time"
"github.com/amnezia-vpn/amneziawg-go/conn"
+ "github.com/amnezia-vpn/amneziawg-go/device/awg"
)
type Peer struct {
@@ -137,7 +138,10 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error {
if err == nil {
var totalLen uint64
for _, b := range buffers {
- peer.device.awg.HandshakeHandler.PacketCounter++
+ // TODO
+ awg.PacketCounter.Inc()
+ peer.device.log.Verbosef("%v - Sending %d bytes to %s; pc: %d", peer, len(b), endpoint)
+
totalLen += uint64(len(b))
}
peer.txBytes.Add(totalLen)
diff --git a/device/send.go b/device/send.go
index 2e446cd..f87d4ae 100644
--- a/device/send.go
+++ b/device/send.go
@@ -134,14 +134,15 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
// set junks depending on packet type
junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk()
if junks == nil {
- junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk()
+ peer.device.log.Verbosef("%v - No special junks defined, using controlled", peer)
+ junks = peer.device.awg.HandshakeHandler.GenerateControlledJunk()
}
peer.device.awg.ASecMux.RUnlock()
} else {
- junks = make([][]byte, peer.device.awg.ASecCfg.JunkPacketCount)
+ junks = make([][]byte, 0, peer.device.awg.ASecCfg.JunkPacketCount)
}
peer.device.awg.ASecMux.RLock()
- err := peer.device.awg.JunkCreator.CreateJunkPackets(junks)
+ err := peer.device.awg.JunkCreator.CreateJunkPackets(&junks)
peer.device.awg.ASecMux.RUnlock()
if err != nil {
diff --git a/go.mod b/go.mod
index 5896772..68a3d0d 100644
--- a/go.mod
+++ b/go.mod
@@ -5,6 +5,7 @@ go 1.24
require (
github.com/stretchr/testify v1.10.0
github.com/tevino/abool v1.2.0
+ go.uber.org/atomic v1.11.0
golang.org/x/crypto v0.36.0
golang.org/x/net v0.37.0
golang.org/x/sys v0.31.0
diff --git a/go.sum b/go.sum
index 4b1e64a..de233a5 100644
--- a/go.sum
+++ b/go.sum
@@ -10,6 +10,8 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tevino/abool v1.2.0 h1:heAkClL8H6w+mK5md9dzsuohKeXHUpY7Vw0ZCKW+huA=
github.com/tevino/abool v1.2.0/go.mod h1:qc66Pna1RiIsPa7O4Egxxs9OqkuxDX55zznh9K07Tzg=
+go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
+go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=