feat: test

This commit is contained in:
Mark Puha 2025-06-11 20:12:36 +02:00
parent c66702372d
commit f6c385f6a7
14 changed files with 240 additions and 97 deletions

View file

@ -22,18 +22,18 @@ func NewJunkCreator(aSecCfg aSecCfgType) (junkCreator, error) {
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) CreateJunkPackets(junks [][]byte) error {
func (jc *junkCreator) CreateJunkPackets(junks *[][]byte) error {
if jc.aSecCfg.JunkPacketCount == 0 {
return nil
}
for i := range jc.aSecCfg.JunkPacketCount {
for range jc.aSecCfg.JunkPacketCount {
packetSize := jc.randomPacketSize()
junk, err := jc.randomJunkWithSize(packetSize)
if err != nil {
return fmt.Errorf("create junk packet: %v", err)
}
junks[i] = junk
*junks = append(*junks, junk)
}
return nil
}

View file

@ -35,7 +35,7 @@ func Test_junkCreator_createJunkPackets(t *testing.T) {
}
t.Run("valid", func(t *testing.T) {
got := make([][]byte, jc.aSecCfg.JunkPacketCount)
err := jc.CreateJunkPackets(got)
err := jc.CreateJunkPackets(&got)
if err != nil {
t.Errorf(
"junkCreator.createJunkPackets() = %v; failed",

View file

@ -3,18 +3,22 @@ package awg
import (
"errors"
"time"
"go.uber.org/atomic"
)
// TODO: atomic/ and better way to use this
var PacketCounter *atomic.Uint64 = atomic.NewUint64(0)
type SpecialHandshakeHandler struct {
isFirstDone bool
SpecialJunk TagJunkGeneratorHandler
ControlledJunk TagJunkGeneratorHandler
nextItime time.Time
ITimeout time.Duration // seconds
// TODO: maybe atomic?
PacketCounter uint64
IsSet bool
IsSet bool
}
func (handler *SpecialHandshakeHandler) Validate() error {
@ -29,13 +33,21 @@ func (handler *SpecialHandshakeHandler) Validate() error {
}
func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte {
// TODO: distiungish between first and the rest of the packets
if !handler.SpecialJunk.IsDefined() {
return nil
}
// TODO: create tests
if !handler.isFirstDone {
handler.isFirstDone = true
handler.nextItime = time.Now().Add(time.Duration(handler.ITimeout))
return nil
}
if !handler.isTimeToSendSpecial() {
return nil
}
rv := handler.SpecialJunk.GeneratePackets()
handler.nextItime = time.Now().Add(time.Duration(handler.ITimeout))
return rv
@ -45,6 +57,10 @@ func (handler *SpecialHandshakeHandler) isTimeToSendSpecial() bool {
return time.Now().After(handler.nextItime)
}
func (handler *SpecialHandshakeHandler) PrepareControlledJunk() [][]byte {
func (handler *SpecialHandshakeHandler) GenerateControlledJunk() [][]byte {
if !handler.ControlledJunk.IsDefined() {
return nil
}
return handler.ControlledJunk.GeneratePackets()
}

View file

@ -10,6 +10,7 @@ import (
"time"
v2 "math/rand/v2"
// "go.uber.org/atomic"
)
type Generator interface {
@ -33,9 +34,8 @@ func (bg *BytesGenerator) Size() int {
}
func newBytesGenerator(param string) (Generator, error) {
isNotHex := !strings.HasPrefix(param, "0x") ||
!strings.HasPrefix(param, "0x") && !isHexString(param)
if isNotHex {
hasPrefix := strings.HasPrefix(param, "0x") || strings.HasPrefix(param, "0X")
if !hasPrefix {
return nil, fmt.Errorf("not correct hex: %s", param)
}
@ -47,17 +47,6 @@ func newBytesGenerator(param string) (Generator, error) {
return &BytesGenerator{value: hex, size: len(hex)}, nil
}
func isHexString(s string) bool {
for _, char := range s {
if !((char >= '0' && char <= '9') ||
(char >= 'a' && char <= 'f') ||
(char >= 'A' && char <= 'F')) {
return false
}
}
return len(s) > 0
}
func hexToBytes(hexStr string) ([]byte, error) {
hexStr = strings.TrimPrefix(hexStr, "0x")
hexStr = strings.TrimPrefix(hexStr, "0X")
@ -110,6 +99,7 @@ type TimestampGenerator struct {
func (tg *TimestampGenerator) Generate() []byte {
buf := make([]byte, 8)
binary.BigEndian.PutUint64(buf, uint64(time.Now().Unix()))
fmt.Printf("timestamp: %v\n", buf)
return buf
}
@ -130,6 +120,7 @@ type WaitTimeoutGenerator struct {
}
func (wtg *WaitTimeoutGenerator) Generate() []byte {
fmt.Printf("sleep: %d\n", wtg.waitTimeout.Milliseconds())
time.Sleep(wtg.waitTimeout)
return []byte{}
}
@ -139,14 +130,39 @@ func (wtg *WaitTimeoutGenerator) Size() int {
}
func newWaitTimeoutGenerator(param string) (Generator, error) {
size, err := strconv.Atoi(param)
t, err := strconv.Atoi(param)
if err != nil {
return nil, fmt.Errorf("timeout parse int: %w", err)
}
if size > 5000 {
return nil, fmt.Errorf("timeout size must be less than 5000ms")
if t > 5000 {
return nil, fmt.Errorf("timeout must be less than 5000ms")
}
return &WaitTimeoutGenerator{}, nil
return &WaitTimeoutGenerator{waitTimeout: time.Duration(t) * time.Millisecond}, nil
}
type PacketCounterGenerator struct {
// counter *atomic.Uint64
}
func (c *PacketCounterGenerator) Generate() []byte {
buf := make([]byte, 8)
// TODO: better way to handle counter tag
binary.BigEndian.PutUint64(buf, PacketCounter.Load())
fmt.Printf("packet %d; counter: %v\n", PacketCounter.Load(), buf)
return buf
}
func (c *PacketCounterGenerator) Size() int {
return 8
}
func newPacketCounterGenerator(param string) (Generator, error) {
if len(param) != 0 {
return nil, fmt.Errorf("packet counter param needs to be empty: %s", param)
}
// return &PacketCounterGenerator{counter: atomic.NewUint64(0)}, nil
return &PacketCounterGenerator{}, nil
}

View file

@ -1,6 +1,7 @@
package awg
import (
"encoding/binary"
"fmt"
"testing"
@ -32,7 +33,14 @@ func Test_newBytesGenerator(t *testing.T) {
wantErr: fmt.Errorf("not correct hex"),
},
{
name: "not only hex value",
name: "not only hex value with X",
args: args{
param: "0X12345q",
},
wantErr: fmt.Errorf("not correct hex"),
},
{
name: "not only hex value with x",
args: args{
param: "0x12345q",
},
@ -127,3 +135,55 @@ func Test_newRandomPacketGenerator(t *testing.T) {
})
}
}
func TestPacketCounterGenerator(t *testing.T) {
tests := []struct {
name string
param string
wantErr bool
}{
{
name: "Valid empty param",
param: "",
wantErr: false,
},
{
name: "Invalid non-empty param",
param: "anything",
wantErr: true,
},
}
for _, tc := range tests {
tc := tc // capture range variable
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
gen, err := newPacketCounterGenerator(tc.param)
if tc.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, 8, gen.Size())
// Reset counter to known value for test
initialCount := uint64(42)
PacketCounter.Store(initialCount)
output := gen.Generate()
require.Equal(t, 8, len(output))
// Verify counter value in output
counterValue := binary.BigEndian.Uint64(output)
require.Equal(t, initialCount, counterValue)
// Increment counter and verify change
PacketCounter.Add(1)
output = gen.Generate()
counterValue = binary.BigEndian.Uint64(output)
require.Equal(t, initialCount+1, counterValue)
})
}
}

View file

@ -12,7 +12,7 @@ type TagJunkGenerator struct {
}
func newTagJunkGenerator(name string, size int) TagJunkGenerator {
return TagJunkGenerator{name: name, generators: make([]Generator, size)}
return TagJunkGenerator{name: name, generators: make([]Generator, 0, size)}
}
func (tg *TagJunkGenerator) append(generator Generator) {
@ -40,7 +40,7 @@ func (tg *TagJunkGenerator) nameIndex() (int, error) {
index, err := strconv.Atoi(tg.name[1:2])
if err != nil {
return 0, fmt.Errorf("name should be 2 char long: %w", err)
return 0, fmt.Errorf("name 2 char should be an int %w", err)
}
return index, nil
}

View file

@ -3,23 +3,26 @@ package awg
import "fmt"
type TagJunkGeneratorHandler struct {
generators []TagJunkGenerator
length int
// Jc
DefaultJunkCount int
tagGenerators []TagJunkGenerator
length int
DefaultJunkCount int // Jc
}
func (handler *TagJunkGeneratorHandler) AppendGenerator(generators TagJunkGenerator) {
handler.generators = append(handler.generators, generators)
handler.tagGenerators = append(handler.tagGenerators, generators)
handler.length++
}
func (handler *TagJunkGeneratorHandler) IsDefined() bool {
return len(handler.tagGenerators) > 0
}
// validate that packets were defined consecutively
func (handler *TagJunkGeneratorHandler) Validate() error {
seen := make([]bool, len(handler.generators))
for _, generator := range handler.generators {
seen := make([]bool, len(handler.tagGenerators))
for _, generator := range handler.tagGenerators {
index, err := generator.nameIndex()
if index > len(handler.generators) {
if index > len(handler.tagGenerators) {
return fmt.Errorf("junk packet index should be consecutive")
}
if err != nil {
@ -39,10 +42,10 @@ func (handler *TagJunkGeneratorHandler) Validate() error {
}
func (handler *TagJunkGeneratorHandler) GeneratePackets() [][]byte {
var rv = make([][]byte, handler.length+handler.DefaultJunkCount)
for i, generator := range handler.generators {
rv[i] = make([]byte, generator.packetSize)
copy(rv[i], generator.generatePacket())
var rv = make([][]byte, 0, handler.length+handler.DefaultJunkCount)
for i, tagGenerator := range handler.tagGenerators {
rv = append(rv, make([]byte, tagGenerator.packetSize))
copy(rv[i], tagGenerator.generatePacket())
}
return rv

View file

@ -8,8 +8,6 @@ import (
)
func TestTagJunkGeneratorHandlerAppendGenerator(t *testing.T) {
t.Parallel()
tests := []struct {
name string
generator TagJunkGenerator
@ -28,20 +26,18 @@ func TestTagJunkGeneratorHandlerAppendGenerator(t *testing.T) {
// Initial length should be 0
require.Equal(t, 0, handler.length)
require.Empty(t, handler.generators)
require.Empty(t, handler.tagGenerators)
// After append, length should be 1 and generator should be added
handler.AppendGenerator(tt.generator)
require.Equal(t, 1, handler.length)
require.Len(t, handler.generators, 1)
require.Equal(t, tt.generator, handler.generators[0])
require.Len(t, handler.tagGenerators, 1)
require.Equal(t, tt.generator, handler.tagGenerators[0])
})
}
}
func TestTagJunkGeneratorHandlerValidate(t *testing.T) {
t.Parallel()
tests := []struct {
name string
generators []TagJunkGenerator
@ -49,12 +45,13 @@ func TestTagJunkGeneratorHandlerValidate(t *testing.T) {
errMsg string
}{
{
name: "valid consecutive indices",
name: "bad start",
generators: []TagJunkGenerator{
newTagJunkGenerator("t1", 10),
newTagJunkGenerator("t2", 10),
newTagJunkGenerator("t3", 10),
newTagJunkGenerator("t4", 10),
},
wantErr: false,
wantErr: true,
errMsg: "junk packet index should be consecutive",
},
{
name: "non-consecutive indices",
@ -65,6 +62,16 @@ func TestTagJunkGeneratorHandlerValidate(t *testing.T) {
wantErr: true,
errMsg: "junk packet index should be consecutive",
},
{
name: "consecutive indices",
generators: []TagJunkGenerator{
newTagJunkGenerator("t1", 10),
newTagJunkGenerator("t2", 10),
newTagJunkGenerator("t3", 10),
newTagJunkGenerator("t4", 10),
newTagJunkGenerator("t5", 10),
},
},
{
name: "nameIndex error",
generators: []TagJunkGenerator{
@ -88,16 +95,14 @@ func TestTagJunkGeneratorHandlerValidate(t *testing.T) {
if tt.wantErr {
require.Error(t, err)
require.Contains(t, err.Error(), tt.errMsg)
} else {
require.NoError(t, err)
return
}
require.NoError(t, err)
})
}
}
func TestTagJunkGeneratorHandlerGenerate(t *testing.T) {
t.Parallel()
mockByte1 := []byte{0x01, 0x02}
mockByte2 := []byte{0x03, 0x04, 0x05}
mockGen1 := internal.NewMockByteGenerator(mockByte1)

View file

@ -20,7 +20,7 @@ const (
var generatorCreator = map[EnumTag]newGenerator{
BytesEnumTag: newBytesGenerator,
CounterEnumTag: func(s string) (Generator, error) { return &BytesGenerator{}, nil },
CounterEnumTag: newPacketCounterGenerator,
TimestampEnumTag: newTimestampGenerator,
RandomBytesEnumTag: newRandomPacketGenerator,
WaitTimeoutEnumTag: newWaitTimeoutGenerator,
@ -89,6 +89,15 @@ func Parse(name, input string) (TagJunkGenerator, error) {
return TagJunkGenerator{}, fmt.Errorf("gen: %w", err)
}
// TODO: handle counter tag
// if tag.Name == CounterEnumTag {
// packetCounter, ok := generator.(*PacketCounterGenerator)
// if !ok {
// log.Fatalf("packet counter generator expected, got %T", generator)
// }
// PacketCounter = packetCounter.counter
// }
rv.append(generator)
}

View file

@ -91,7 +91,7 @@ func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
return
}
func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
func genAWGConfigs(tb testing.TB, cfg ...string) (cfgs, endpointCfgs [2]string) {
var key1, key2 NoisePrivateKey
_, err := rand.Read(key1[:])
if err != nil {
@ -103,46 +103,35 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
}
pub1, pub2 := key1.publicKey(), key2.publicKey()
cfgs[0] = uapiCfg(
args0 := append([]string(nil), cfg...)
args0 = append(args0, []string{
"private_key", hex.EncodeToString(key1[:]),
"listen_port", "0",
"replace_peers", "true",
"jc", "5",
"jmin", "500",
"jmax", "1000",
"s1", "30",
"s2", "40",
"h1", "123456",
"h2", "67543",
"h4", "32345",
"h3", "123123",
"public_key", hex.EncodeToString(pub2[:]),
"protocol_version", "1",
"replace_allowed_ips", "true",
"allowed_ip", "1.0.0.2/32",
)
}...)
cfgs[0] = uapiCfg(args0...)
endpointCfgs[0] = uapiCfg(
"public_key", hex.EncodeToString(pub2[:]),
"endpoint", "127.0.0.1:%d",
)
cfgs[1] = uapiCfg(
args1 := append([]string(nil), cfg...)
args1 = append(args1, []string{
"private_key", hex.EncodeToString(key2[:]),
"listen_port", "0",
"replace_peers", "true",
"jc", "5",
"jmin", "500",
"jmax", "1000",
"s1", "30",
"s2", "40",
"h1", "123456",
"h2", "67543",
"h4", "32345",
"h3", "123123",
"public_key", hex.EncodeToString(pub1[:]),
"protocol_version", "1",
"replace_allowed_ips", "true",
"allowed_ip", "1.0.0.1/32",
)
}...)
cfgs[1] = uapiCfg(args1...)
endpointCfgs[1] = uapiCfg(
"public_key", hex.EncodeToString(pub1[:]),
"endpoint", "127.0.0.1:%d",
@ -214,11 +203,12 @@ func (pair *testPair) Send(
// genTestPair creates a testPair.
func genTestPair(
tb testing.TB,
realSocket, withASecurity bool,
realSocket bool,
extraCfg ...string,
) (pair testPair) {
var cfg, endpointCfg [2]string
if withASecurity {
cfg, endpointCfg = genASecurityConfigs(tb)
if len(extraCfg) > 0 {
cfg, endpointCfg = genAWGConfigs(tb, extraCfg...)
} else {
cfg, endpointCfg = genConfigs(tb)
}
@ -265,7 +255,7 @@ func genTestPair(
func TestTwoDevicePing(t *testing.T) {
goroutineLeakCheck(t)
pair := genTestPair(t, true, false)
pair := genTestPair(t, true)
t.Run("ping 1.0.0.1", func(t *testing.T) {
pair.Send(t, Ping, nil)
})
@ -275,9 +265,45 @@ func TestTwoDevicePing(t *testing.T) {
}
// Run test with -race=false to avoid the race for setting the default msgTypes 2 times
func TestASecurityTwoDevicePing(t *testing.T) {
func TestAWGDevicePing(t *testing.T) {
goroutineLeakCheck(t)
pair := genTestPair(t, true, true)
pair := genTestPair(t, true,
"jc", "5",
"jmin", "500",
"jmax", "1000",
"s1", "30",
"s2", "40",
"h1", "123456",
"h2", "67543",
"h4", "32345",
"h3", "123123",
)
t.Run("ping 1.0.0.1", func(t *testing.T) {
pair.Send(t, Ping, nil)
})
t.Run("ping 1.0.0.2", func(t *testing.T) {
pair.Send(t, Pong, nil)
})
}
func TestAWGHandshakeDevicePing(t *testing.T) {
goroutineLeakCheck(t)
pair := genTestPair(t, true,
// "i1", "<b 0xf6ab3267fa><c><b 0xf6ab><t><r 10><wt 10>",
// "i2", "<b 0xf6ab3267fa><r 100>",
"j1", "<b 0xffffffff><c><b 0xf6ab><t><r 10>",
"j2", "<c><b 0xf6ab><t><wt 1000>",
"j3", "<t><b 0xf6ab><c><r 10>",
// "jc", "1",
// "jmin", "500",
// "jmax", "1000",
// "s1", "30",
// "s2", "40",
// "h1", "123456",
// "h2", "67543",
// "h4", "32345",
// "h3", "123123",
)
t.Run("ping 1.0.0.1", func(t *testing.T) {
pair.Send(t, Ping, nil)
})
@ -292,7 +318,7 @@ func TestUpDown(t *testing.T) {
const otrials = 10
for n := 0; n < otrials; n++ {
pair := genTestPair(t, false, false)
pair := genTestPair(t, false)
for i := range pair {
for k := range pair[i].dev.peers.keyMap {
pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:])))
@ -326,7 +352,7 @@ func TestUpDown(t *testing.T) {
// TestConcurrencySafety does other things concurrently with tunnel use.
// It is intended to be used with the race detector to catch data races.
func TestConcurrencySafety(t *testing.T) {
pair := genTestPair(t, true, false)
pair := genTestPair(t, true)
done := make(chan struct{})
const warmupIters = 10
@ -407,7 +433,7 @@ func TestConcurrencySafety(t *testing.T) {
}
func BenchmarkLatency(b *testing.B) {
pair := genTestPair(b, true, false)
pair := genTestPair(b, true)
// Establish a connection.
pair.Send(b, Ping, nil)
@ -421,7 +447,7 @@ func BenchmarkLatency(b *testing.B) {
}
func BenchmarkThroughput(b *testing.B) {
pair := genTestPair(b, true, false)
pair := genTestPair(b, true)
// Establish a connection.
pair.Send(b, Ping, nil)
@ -465,7 +491,7 @@ func BenchmarkThroughput(b *testing.B) {
}
func BenchmarkUAPIGet(b *testing.B) {
pair := genTestPair(b, true, false)
pair := genTestPair(b, true)
pair.Send(b, Ping, nil)
pair.Send(b, Pong, nil)
b.ReportAllocs()

View file

@ -13,6 +13,7 @@ import (
"time"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/device/awg"
)
type Peer struct {
@ -137,7 +138,10 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error {
if err == nil {
var totalLen uint64
for _, b := range buffers {
peer.device.awg.HandshakeHandler.PacketCounter++
// TODO
awg.PacketCounter.Inc()
peer.device.log.Verbosef("%v - Sending %d bytes to %s; pc: %d", peer, len(b), endpoint)
totalLen += uint64(len(b))
}
peer.txBytes.Add(totalLen)

View file

@ -134,14 +134,15 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
// set junks depending on packet type
junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk()
if junks == nil {
junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk()
peer.device.log.Verbosef("%v - No special junks defined, using controlled", peer)
junks = peer.device.awg.HandshakeHandler.GenerateControlledJunk()
}
peer.device.awg.ASecMux.RUnlock()
} else {
junks = make([][]byte, peer.device.awg.ASecCfg.JunkPacketCount)
junks = make([][]byte, 0, peer.device.awg.ASecCfg.JunkPacketCount)
}
peer.device.awg.ASecMux.RLock()
err := peer.device.awg.JunkCreator.CreateJunkPackets(junks)
err := peer.device.awg.JunkCreator.CreateJunkPackets(&junks)
peer.device.awg.ASecMux.RUnlock()
if err != nil {

1
go.mod
View file

@ -5,6 +5,7 @@ go 1.24
require (
github.com/stretchr/testify v1.10.0
github.com/tevino/abool v1.2.0
go.uber.org/atomic v1.11.0
golang.org/x/crypto v0.36.0
golang.org/x/net v0.37.0
golang.org/x/sys v0.31.0

2
go.sum
View file

@ -10,6 +10,8 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tevino/abool v1.2.0 h1:heAkClL8H6w+mK5md9dzsuohKeXHUpY7Vw0ZCKW+huA=
github.com/tevino/abool v1.2.0/go.mod h1:qc66Pna1RiIsPa7O4Egxxs9OqkuxDX55zznh9K07Tzg=
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=