From c66702372d675ca60e9180dbef8e1ce623d39cec Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Mon, 9 Jun 2025 17:36:37 +0200
Subject: [PATCH] feat: create tests
---
device/awg/internal/mock.go | 37 ++++
device/awg/special_handshake_handler.go | 5 +-
device/awg/tag_junk_generator.go | 2 +-
device/awg/tag_junk_generator_handler.go | 10 +-
device/awg/tag_junk_generator_handler_test.go | 144 ++++++++++++
device/awg/tag_junk_generator_test.go | 206 ++++++++++++++++++
6 files changed, 398 insertions(+), 6 deletions(-)
create mode 100644 device/awg/internal/mock.go
create mode 100644 device/awg/tag_junk_generator_handler_test.go
create mode 100644 device/awg/tag_junk_generator_test.go
diff --git a/device/awg/internal/mock.go b/device/awg/internal/mock.go
new file mode 100644
index 0000000..a2e1c95
--- /dev/null
+++ b/device/awg/internal/mock.go
@@ -0,0 +1,37 @@
+package internal
+
+type mockGenerator struct {
+ size int
+}
+
+func NewMockGenerator(size int) mockGenerator {
+ return mockGenerator{size: size}
+}
+
+func (m mockGenerator) Generate() []byte {
+ return make([]byte, m.size)
+}
+
+func (m mockGenerator) Size() int {
+ return m.size
+}
+
+func (m mockGenerator) Name() string {
+ return "mock"
+}
+
+type mockByteGenerator struct {
+ data []byte
+}
+
+func NewMockByteGenerator(data []byte) mockByteGenerator {
+ return mockByteGenerator{data: data}
+}
+
+func (bg mockByteGenerator) Generate() []byte {
+ return bg.data
+}
+
+func (bg mockByteGenerator) Size() int {
+ return len(bg.data)
+}
diff --git a/device/awg/special_handshake_handler.go b/device/awg/special_handshake_handler.go
index 45b048b..17f2a53 100644
--- a/device/awg/special_handshake_handler.go
+++ b/device/awg/special_handshake_handler.go
@@ -11,6 +11,7 @@ type SpecialHandshakeHandler struct {
nextItime time.Time
ITimeout time.Duration // seconds
+
// TODO: maybe atomic?
PacketCounter uint64
IsSet bool
@@ -33,7 +34,7 @@ func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte {
return nil
}
- rv := handler.SpecialJunk.Generate()
+ rv := handler.SpecialJunk.GeneratePackets()
handler.nextItime = time.Now().Add(time.Duration(handler.ITimeout))
@@ -45,5 +46,5 @@ func (handler *SpecialHandshakeHandler) isTimeToSendSpecial() bool {
}
func (handler *SpecialHandshakeHandler) PrepareControlledJunk() [][]byte {
- return handler.ControlledJunk.Generate()
+ return handler.ControlledJunk.GeneratePackets()
}
diff --git a/device/awg/tag_junk_generator.go b/device/awg/tag_junk_generator.go
index e856177..a8b3cde 100644
--- a/device/awg/tag_junk_generator.go
+++ b/device/awg/tag_junk_generator.go
@@ -20,7 +20,7 @@ func (tg *TagJunkGenerator) append(generator Generator) {
tg.packetSize += generator.Size()
}
-func (tg *TagJunkGenerator) generate() []byte {
+func (tg *TagJunkGenerator) generatePacket() []byte {
packet := make([]byte, 0, tg.packetSize)
for _, generator := range tg.generators {
packet = append(packet, generator.Generate()...)
diff --git a/device/awg/tag_junk_generator_handler.go b/device/awg/tag_junk_generator_handler.go
index a2b0746..b8cf9e3 100644
--- a/device/awg/tag_junk_generator_handler.go
+++ b/device/awg/tag_junk_generator_handler.go
@@ -18,7 +18,11 @@ func (handler *TagJunkGeneratorHandler) AppendGenerator(generators TagJunkGenera
func (handler *TagJunkGeneratorHandler) Validate() error {
seen := make([]bool, len(handler.generators))
for _, generator := range handler.generators {
- if index, err := generator.nameIndex(); err != nil {
+ index, err := generator.nameIndex()
+ if index > len(handler.generators) {
+ return fmt.Errorf("junk packet index should be consecutive")
+ }
+ if err != nil {
return fmt.Errorf("name index: %w", err)
} else {
seen[index-1] = true
@@ -34,11 +38,11 @@ func (handler *TagJunkGeneratorHandler) Validate() error {
return nil
}
-func (handler *TagJunkGeneratorHandler) Generate() [][]byte {
+func (handler *TagJunkGeneratorHandler) GeneratePackets() [][]byte {
var rv = make([][]byte, handler.length+handler.DefaultJunkCount)
for i, generator := range handler.generators {
rv[i] = make([]byte, generator.packetSize)
- copy(rv[i], generator.generate())
+ copy(rv[i], generator.generatePacket())
}
return rv
diff --git a/device/awg/tag_junk_generator_handler_test.go b/device/awg/tag_junk_generator_handler_test.go
new file mode 100644
index 0000000..2c36cda
--- /dev/null
+++ b/device/awg/tag_junk_generator_handler_test.go
@@ -0,0 +1,144 @@
+package awg
+
+import (
+ "testing"
+
+ "github.com/amnezia-vpn/amneziawg-go/device/awg/internal"
+ "github.com/stretchr/testify/require"
+)
+
+func TestTagJunkGeneratorHandlerAppendGenerator(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ generator TagJunkGenerator
+ }{
+ {
+ name: "append single generator",
+ generator: newTagJunkGenerator("t1", 10),
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ handler := &TagJunkGeneratorHandler{}
+
+ // Initial length should be 0
+ require.Equal(t, 0, handler.length)
+ require.Empty(t, handler.generators)
+
+ // After append, length should be 1 and generator should be added
+ handler.AppendGenerator(tt.generator)
+ require.Equal(t, 1, handler.length)
+ require.Len(t, handler.generators, 1)
+ require.Equal(t, tt.generator, handler.generators[0])
+ })
+ }
+}
+
+func TestTagJunkGeneratorHandlerValidate(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ generators []TagJunkGenerator
+ wantErr bool
+ errMsg string
+ }{
+ {
+ name: "valid consecutive indices",
+ generators: []TagJunkGenerator{
+ newTagJunkGenerator("t1", 10),
+ newTagJunkGenerator("t2", 10),
+ },
+ wantErr: false,
+ },
+ {
+ name: "non-consecutive indices",
+ generators: []TagJunkGenerator{
+ newTagJunkGenerator("t1", 10),
+ newTagJunkGenerator("t3", 10), // Missing t2
+ },
+ wantErr: true,
+ errMsg: "junk packet index should be consecutive",
+ },
+ {
+ name: "nameIndex error",
+ generators: []TagJunkGenerator{
+ newTagJunkGenerator("error", 10),
+ },
+ wantErr: true,
+ errMsg: "name must be 2 character long",
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ handler := &TagJunkGeneratorHandler{}
+ for _, gen := range tt.generators {
+ handler.AppendGenerator(gen)
+ }
+
+ err := handler.Validate()
+ if tt.wantErr {
+ require.Error(t, err)
+ require.Contains(t, err.Error(), tt.errMsg)
+ } else {
+ require.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestTagJunkGeneratorHandlerGenerate(t *testing.T) {
+ t.Parallel()
+
+ mockByte1 := []byte{0x01, 0x02}
+ mockByte2 := []byte{0x03, 0x04, 0x05}
+ mockGen1 := internal.NewMockByteGenerator(mockByte1)
+ mockGen2 := internal.NewMockByteGenerator(mockByte2)
+
+ tests := []struct {
+ name string
+ setupGenerator func() []TagJunkGenerator
+ expected [][]byte
+ }{
+ {
+ name: "generate with no default junk",
+ setupGenerator: func() []TagJunkGenerator {
+ tg1 := newTagJunkGenerator("t1", 0)
+ tg1.append(mockGen1)
+ tg1.append(mockGen2)
+ tg2 := newTagJunkGenerator("t2", 0)
+ tg2.append(mockGen2)
+ tg2.append(mockGen1)
+
+ return []TagJunkGenerator{tg1, tg2}
+ },
+ expected: [][]byte{
+ append(mockByte1, mockByte2...),
+ append(mockByte2, mockByte1...),
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ handler := &TagJunkGeneratorHandler{}
+ generators := tt.setupGenerator()
+ for _, gen := range generators {
+ handler.AppendGenerator(gen)
+ }
+
+ result := handler.GeneratePackets()
+ require.Equal(t, result, tt.expected)
+ })
+ }
+}
diff --git a/device/awg/tag_junk_generator_test.go b/device/awg/tag_junk_generator_test.go
new file mode 100644
index 0000000..ee4b77e
--- /dev/null
+++ b/device/awg/tag_junk_generator_test.go
@@ -0,0 +1,206 @@
+package awg
+
+import (
+ "testing"
+
+ "github.com/amnezia-vpn/amneziawg-go/device/awg/internal"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewTagJunkGenerator(t *testing.T) {
+ t.Parallel()
+
+ testCases := []struct {
+ name string
+ genName string
+ size int
+ expected TagJunkGenerator
+ }{
+ {
+ name: "Create new generator with empty name",
+ genName: "",
+ size: 0,
+ expected: TagJunkGenerator{
+ name: "",
+ packetSize: 0,
+ generators: make([]Generator, 0),
+ },
+ },
+ {
+ name: "Create new generator with valid name",
+ genName: "T1",
+ size: 0,
+ expected: TagJunkGenerator{
+ name: "T1",
+ packetSize: 0,
+ generators: make([]Generator, 0),
+ },
+ },
+ {
+ name: "Create new generator with non-zero size",
+ genName: "T2",
+ size: 5,
+ expected: TagJunkGenerator{
+ name: "T2",
+ packetSize: 0,
+ generators: make([]Generator, 5),
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ tc := tc // capture range variable
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+ result := newTagJunkGenerator(tc.genName, tc.size)
+ require.Equal(t, tc.expected.name, result.name)
+ require.Equal(t, tc.expected.packetSize, result.packetSize)
+ require.Len(t, result.generators, len(tc.expected.generators))
+ })
+ }
+}
+
+func TestTagJunkGeneratorAppend(t *testing.T) {
+ t.Parallel()
+
+ testCases := []struct {
+ name string
+ initialState TagJunkGenerator
+ mockSize int
+ expectedLength int
+ expectedSize int
+ }{
+ {
+ name: "Append to empty generator",
+ initialState: newTagJunkGenerator("T1", 0),
+ mockSize: 5,
+ expectedLength: 1,
+ expectedSize: 5,
+ },
+ {
+ name: "Append to non-empty generator",
+ initialState: TagJunkGenerator{name: "T2", packetSize: 10, generators: make([]Generator, 2)},
+ mockSize: 7,
+ expectedLength: 3, // 2 existing + 1 new
+ expectedSize: 17, // 10 + 7
+ },
+ }
+
+ for _, tc := range testCases {
+ tc := tc // capture range variable
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ tg := tc.initialState
+ mockGen := internal.NewMockGenerator(tc.mockSize)
+
+ tg.append(mockGen)
+
+ require.Equal(t, tc.expectedLength, len(tg.generators))
+ require.Equal(t, tc.expectedSize, tg.packetSize)
+ })
+ }
+}
+
+func TestTagJunkGeneratorGenerate(t *testing.T) {
+ t.Parallel()
+
+ // Create mock generators for testing
+ mockGen1 := internal.NewMockByteGenerator([]byte{0x01, 0x02})
+ mockGen2 := internal.NewMockByteGenerator([]byte{0x03, 0x04, 0x05})
+
+ testCases := []struct {
+ name string
+ setupGenerator func() TagJunkGenerator
+ expected []byte
+ }{
+ {
+ name: "Generate with empty generators",
+ setupGenerator: func() TagJunkGenerator {
+ return newTagJunkGenerator("T1", 0)
+ },
+ expected: []byte{},
+ },
+ {
+ name: "Generate with single generator",
+ setupGenerator: func() TagJunkGenerator {
+ tg := newTagJunkGenerator("T2", 0)
+ tg.append(mockGen1)
+ return tg
+ },
+ expected: []byte{0x01, 0x02},
+ },
+ {
+ name: "Generate with multiple generators",
+ setupGenerator: func() TagJunkGenerator {
+ tg := newTagJunkGenerator("T3", 0)
+ tg.append(mockGen1)
+ tg.append(mockGen2)
+ return tg
+ },
+ expected: []byte{0x01, 0x02, 0x03, 0x04, 0x05},
+ },
+ }
+
+ for _, tc := range testCases {
+ tc := tc // capture range variable
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ tg := tc.setupGenerator()
+ result := tg.generatePacket()
+
+ require.Equal(t, tc.expected, result)
+ })
+ }
+}
+
+func TestTagJunkGeneratorNameIndex(t *testing.T) {
+ t.Parallel()
+
+ testCases := []struct {
+ name string
+ generatorName string
+ expectedIndex int
+ expectError bool
+ }{
+ {
+ name: "Valid name with digit",
+ generatorName: "T5",
+ expectedIndex: 5,
+ expectError: false,
+ },
+ {
+ name: "Invalid name - too short",
+ generatorName: "T",
+ expectError: true,
+ },
+ {
+ name: "Invalid name - too long",
+ generatorName: "T55",
+ expectError: true,
+ },
+ {
+ name: "Invalid name - non-digit second character",
+ generatorName: "TX",
+ expectError: true,
+ },
+ }
+
+ for _, tc := range testCases {
+ tc := tc // capture range variable
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ tg := TagJunkGenerator{name: tc.generatorName}
+ index, err := tg.nameIndex()
+
+ if tc.expectError {
+ require.Error(t, err)
+ } else {
+ require.NoError(t, err)
+ require.Equal(t, tc.expectedIndex, index)
+ }
+ })
+ }
+}