From c66702372d675ca60e9180dbef8e1ce623d39cec Mon Sep 17 00:00:00 2001 From: Mark Puha Date: Mon, 9 Jun 2025 17:36:37 +0200 Subject: [PATCH] feat: create tests --- device/awg/internal/mock.go | 37 ++++ device/awg/special_handshake_handler.go | 5 +- device/awg/tag_junk_generator.go | 2 +- device/awg/tag_junk_generator_handler.go | 10 +- device/awg/tag_junk_generator_handler_test.go | 144 ++++++++++++ device/awg/tag_junk_generator_test.go | 206 ++++++++++++++++++ 6 files changed, 398 insertions(+), 6 deletions(-) create mode 100644 device/awg/internal/mock.go create mode 100644 device/awg/tag_junk_generator_handler_test.go create mode 100644 device/awg/tag_junk_generator_test.go diff --git a/device/awg/internal/mock.go b/device/awg/internal/mock.go new file mode 100644 index 0000000..a2e1c95 --- /dev/null +++ b/device/awg/internal/mock.go @@ -0,0 +1,37 @@ +package internal + +type mockGenerator struct { + size int +} + +func NewMockGenerator(size int) mockGenerator { + return mockGenerator{size: size} +} + +func (m mockGenerator) Generate() []byte { + return make([]byte, m.size) +} + +func (m mockGenerator) Size() int { + return m.size +} + +func (m mockGenerator) Name() string { + return "mock" +} + +type mockByteGenerator struct { + data []byte +} + +func NewMockByteGenerator(data []byte) mockByteGenerator { + return mockByteGenerator{data: data} +} + +func (bg mockByteGenerator) Generate() []byte { + return bg.data +} + +func (bg mockByteGenerator) Size() int { + return len(bg.data) +} diff --git a/device/awg/special_handshake_handler.go b/device/awg/special_handshake_handler.go index 45b048b..17f2a53 100644 --- a/device/awg/special_handshake_handler.go +++ b/device/awg/special_handshake_handler.go @@ -11,6 +11,7 @@ type SpecialHandshakeHandler struct { nextItime time.Time ITimeout time.Duration // seconds + // TODO: maybe atomic? PacketCounter uint64 IsSet bool @@ -33,7 +34,7 @@ func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte { return nil } - rv := handler.SpecialJunk.Generate() + rv := handler.SpecialJunk.GeneratePackets() handler.nextItime = time.Now().Add(time.Duration(handler.ITimeout)) @@ -45,5 +46,5 @@ func (handler *SpecialHandshakeHandler) isTimeToSendSpecial() bool { } func (handler *SpecialHandshakeHandler) PrepareControlledJunk() [][]byte { - return handler.ControlledJunk.Generate() + return handler.ControlledJunk.GeneratePackets() } diff --git a/device/awg/tag_junk_generator.go b/device/awg/tag_junk_generator.go index e856177..a8b3cde 100644 --- a/device/awg/tag_junk_generator.go +++ b/device/awg/tag_junk_generator.go @@ -20,7 +20,7 @@ func (tg *TagJunkGenerator) append(generator Generator) { tg.packetSize += generator.Size() } -func (tg *TagJunkGenerator) generate() []byte { +func (tg *TagJunkGenerator) generatePacket() []byte { packet := make([]byte, 0, tg.packetSize) for _, generator := range tg.generators { packet = append(packet, generator.Generate()...) diff --git a/device/awg/tag_junk_generator_handler.go b/device/awg/tag_junk_generator_handler.go index a2b0746..b8cf9e3 100644 --- a/device/awg/tag_junk_generator_handler.go +++ b/device/awg/tag_junk_generator_handler.go @@ -18,7 +18,11 @@ func (handler *TagJunkGeneratorHandler) AppendGenerator(generators TagJunkGenera func (handler *TagJunkGeneratorHandler) Validate() error { seen := make([]bool, len(handler.generators)) for _, generator := range handler.generators { - if index, err := generator.nameIndex(); err != nil { + index, err := generator.nameIndex() + if index > len(handler.generators) { + return fmt.Errorf("junk packet index should be consecutive") + } + if err != nil { return fmt.Errorf("name index: %w", err) } else { seen[index-1] = true @@ -34,11 +38,11 @@ func (handler *TagJunkGeneratorHandler) Validate() error { return nil } -func (handler *TagJunkGeneratorHandler) Generate() [][]byte { +func (handler *TagJunkGeneratorHandler) GeneratePackets() [][]byte { var rv = make([][]byte, handler.length+handler.DefaultJunkCount) for i, generator := range handler.generators { rv[i] = make([]byte, generator.packetSize) - copy(rv[i], generator.generate()) + copy(rv[i], generator.generatePacket()) } return rv diff --git a/device/awg/tag_junk_generator_handler_test.go b/device/awg/tag_junk_generator_handler_test.go new file mode 100644 index 0000000..2c36cda --- /dev/null +++ b/device/awg/tag_junk_generator_handler_test.go @@ -0,0 +1,144 @@ +package awg + +import ( + "testing" + + "github.com/amnezia-vpn/amneziawg-go/device/awg/internal" + "github.com/stretchr/testify/require" +) + +func TestTagJunkGeneratorHandlerAppendGenerator(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + generator TagJunkGenerator + }{ + { + name: "append single generator", + generator: newTagJunkGenerator("t1", 10), + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + handler := &TagJunkGeneratorHandler{} + + // Initial length should be 0 + require.Equal(t, 0, handler.length) + require.Empty(t, handler.generators) + + // After append, length should be 1 and generator should be added + handler.AppendGenerator(tt.generator) + require.Equal(t, 1, handler.length) + require.Len(t, handler.generators, 1) + require.Equal(t, tt.generator, handler.generators[0]) + }) + } +} + +func TestTagJunkGeneratorHandlerValidate(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + generators []TagJunkGenerator + wantErr bool + errMsg string + }{ + { + name: "valid consecutive indices", + generators: []TagJunkGenerator{ + newTagJunkGenerator("t1", 10), + newTagJunkGenerator("t2", 10), + }, + wantErr: false, + }, + { + name: "non-consecutive indices", + generators: []TagJunkGenerator{ + newTagJunkGenerator("t1", 10), + newTagJunkGenerator("t3", 10), // Missing t2 + }, + wantErr: true, + errMsg: "junk packet index should be consecutive", + }, + { + name: "nameIndex error", + generators: []TagJunkGenerator{ + newTagJunkGenerator("error", 10), + }, + wantErr: true, + errMsg: "name must be 2 character long", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + handler := &TagJunkGeneratorHandler{} + for _, gen := range tt.generators { + handler.AppendGenerator(gen) + } + + err := handler.Validate() + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errMsg) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestTagJunkGeneratorHandlerGenerate(t *testing.T) { + t.Parallel() + + mockByte1 := []byte{0x01, 0x02} + mockByte2 := []byte{0x03, 0x04, 0x05} + mockGen1 := internal.NewMockByteGenerator(mockByte1) + mockGen2 := internal.NewMockByteGenerator(mockByte2) + + tests := []struct { + name string + setupGenerator func() []TagJunkGenerator + expected [][]byte + }{ + { + name: "generate with no default junk", + setupGenerator: func() []TagJunkGenerator { + tg1 := newTagJunkGenerator("t1", 0) + tg1.append(mockGen1) + tg1.append(mockGen2) + tg2 := newTagJunkGenerator("t2", 0) + tg2.append(mockGen2) + tg2.append(mockGen1) + + return []TagJunkGenerator{tg1, tg2} + }, + expected: [][]byte{ + append(mockByte1, mockByte2...), + append(mockByte2, mockByte1...), + }, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + handler := &TagJunkGeneratorHandler{} + generators := tt.setupGenerator() + for _, gen := range generators { + handler.AppendGenerator(gen) + } + + result := handler.GeneratePackets() + require.Equal(t, result, tt.expected) + }) + } +} diff --git a/device/awg/tag_junk_generator_test.go b/device/awg/tag_junk_generator_test.go new file mode 100644 index 0000000..ee4b77e --- /dev/null +++ b/device/awg/tag_junk_generator_test.go @@ -0,0 +1,206 @@ +package awg + +import ( + "testing" + + "github.com/amnezia-vpn/amneziawg-go/device/awg/internal" + "github.com/stretchr/testify/require" +) + +func TestNewTagJunkGenerator(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + genName string + size int + expected TagJunkGenerator + }{ + { + name: "Create new generator with empty name", + genName: "", + size: 0, + expected: TagJunkGenerator{ + name: "", + packetSize: 0, + generators: make([]Generator, 0), + }, + }, + { + name: "Create new generator with valid name", + genName: "T1", + size: 0, + expected: TagJunkGenerator{ + name: "T1", + packetSize: 0, + generators: make([]Generator, 0), + }, + }, + { + name: "Create new generator with non-zero size", + genName: "T2", + size: 5, + expected: TagJunkGenerator{ + name: "T2", + packetSize: 0, + generators: make([]Generator, 5), + }, + }, + } + + for _, tc := range testCases { + tc := tc // capture range variable + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + result := newTagJunkGenerator(tc.genName, tc.size) + require.Equal(t, tc.expected.name, result.name) + require.Equal(t, tc.expected.packetSize, result.packetSize) + require.Len(t, result.generators, len(tc.expected.generators)) + }) + } +} + +func TestTagJunkGeneratorAppend(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + initialState TagJunkGenerator + mockSize int + expectedLength int + expectedSize int + }{ + { + name: "Append to empty generator", + initialState: newTagJunkGenerator("T1", 0), + mockSize: 5, + expectedLength: 1, + expectedSize: 5, + }, + { + name: "Append to non-empty generator", + initialState: TagJunkGenerator{name: "T2", packetSize: 10, generators: make([]Generator, 2)}, + mockSize: 7, + expectedLength: 3, // 2 existing + 1 new + expectedSize: 17, // 10 + 7 + }, + } + + for _, tc := range testCases { + tc := tc // capture range variable + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + tg := tc.initialState + mockGen := internal.NewMockGenerator(tc.mockSize) + + tg.append(mockGen) + + require.Equal(t, tc.expectedLength, len(tg.generators)) + require.Equal(t, tc.expectedSize, tg.packetSize) + }) + } +} + +func TestTagJunkGeneratorGenerate(t *testing.T) { + t.Parallel() + + // Create mock generators for testing + mockGen1 := internal.NewMockByteGenerator([]byte{0x01, 0x02}) + mockGen2 := internal.NewMockByteGenerator([]byte{0x03, 0x04, 0x05}) + + testCases := []struct { + name string + setupGenerator func() TagJunkGenerator + expected []byte + }{ + { + name: "Generate with empty generators", + setupGenerator: func() TagJunkGenerator { + return newTagJunkGenerator("T1", 0) + }, + expected: []byte{}, + }, + { + name: "Generate with single generator", + setupGenerator: func() TagJunkGenerator { + tg := newTagJunkGenerator("T2", 0) + tg.append(mockGen1) + return tg + }, + expected: []byte{0x01, 0x02}, + }, + { + name: "Generate with multiple generators", + setupGenerator: func() TagJunkGenerator { + tg := newTagJunkGenerator("T3", 0) + tg.append(mockGen1) + tg.append(mockGen2) + return tg + }, + expected: []byte{0x01, 0x02, 0x03, 0x04, 0x05}, + }, + } + + for _, tc := range testCases { + tc := tc // capture range variable + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + tg := tc.setupGenerator() + result := tg.generatePacket() + + require.Equal(t, tc.expected, result) + }) + } +} + +func TestTagJunkGeneratorNameIndex(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + generatorName string + expectedIndex int + expectError bool + }{ + { + name: "Valid name with digit", + generatorName: "T5", + expectedIndex: 5, + expectError: false, + }, + { + name: "Invalid name - too short", + generatorName: "T", + expectError: true, + }, + { + name: "Invalid name - too long", + generatorName: "T55", + expectError: true, + }, + { + name: "Invalid name - non-digit second character", + generatorName: "TX", + expectError: true, + }, + } + + for _, tc := range testCases { + tc := tc // capture range variable + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + tg := TagJunkGenerator{name: tc.generatorName} + index, err := tg.nameIndex() + + if tc.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tc.expectedIndex, index) + } + }) + } +}