diff --git a/device/awg/tag_junk_generator_handler.go b/device/awg/tag_junk_generator_handler.go deleted file mode 100644 index 934dadf..0000000 --- a/device/awg/tag_junk_generator_handler.go +++ /dev/null @@ -1,55 +0,0 @@ -package awg - -import "fmt" - -type TagJunkGeneratorHandler struct { - tagGenerators []TagJunkGenerator - length int - DefaultJunkCount int // Jc -} - -func (handler *TagJunkGeneratorHandler) AppendGenerator(generators TagJunkGenerator) { - handler.tagGenerators = append(handler.tagGenerators, generators) - handler.length++ -} - -func (handler *TagJunkGeneratorHandler) IsDefined() bool { - return len(handler.tagGenerators) > 0 -} - -// validate that packets were defined consecutively -func (handler *TagJunkGeneratorHandler) Validate() error { - seen := make([]bool, len(handler.tagGenerators)) - for _, generator := range handler.tagGenerators { - index, err := generator.nameIndex() - if index > len(handler.tagGenerators) { - return fmt.Errorf("junk packet index should be consecutive") - } - if err != nil { - return fmt.Errorf("name index: %w", err) - } else { - seen[index-1] = true - } - } - - for _, found := range seen { - if !found { - return fmt.Errorf("junk packet index should be consecutive") - } - } - - return nil -} - -func (handler *TagJunkGeneratorHandler) GeneratePackets() [][]byte { - var rv = make([][]byte, 0, handler.length+handler.DefaultJunkCount) - - for i, tagGenerator := range handler.tagGenerators { - PacketCounter.Inc() - rv = append(rv, make([]byte, tagGenerator.packetSize)) - copy(rv[i], tagGenerator.generatePacket()) - } - PacketCounter.Add(uint64(handler.DefaultJunkCount)) - - return rv -} diff --git a/device/awg/tag_junk_generator.go b/device/awg/tag_junk_packet_generator.go similarity index 58% rename from device/awg/tag_junk_generator.go rename to device/awg/tag_junk_packet_generator.go index 3d87a46..0de80d3 100644 --- a/device/awg/tag_junk_generator.go +++ b/device/awg/tag_junk_packet_generator.go @@ -5,22 +5,22 @@ import ( "strconv" ) -type TagJunkGenerator struct { +type TagJunkPacketGenerator struct { name string packetSize int generators []Generator } -func newTagJunkGenerator(name string, size int) TagJunkGenerator { - return TagJunkGenerator{name: name, generators: make([]Generator, 0, size)} +func newTagJunkPacketGenerator(name string, size int) TagJunkPacketGenerator { + return TagJunkPacketGenerator{name: name, generators: make([]Generator, 0, size)} } -func (tg *TagJunkGenerator) append(generator Generator) { +func (tg *TagJunkPacketGenerator) append(generator Generator) { tg.generators = append(tg.generators, generator) tg.packetSize += generator.Size() } -func (tg *TagJunkGenerator) generatePacket() []byte { +func (tg *TagJunkPacketGenerator) generatePacket() []byte { packet := make([]byte, 0, tg.packetSize) for _, generator := range tg.generators { packet = append(packet, generator.Generate()...) @@ -29,11 +29,11 @@ func (tg *TagJunkGenerator) generatePacket() []byte { return packet } -func (tg *TagJunkGenerator) Name() string { +func (tg *TagJunkPacketGenerator) Name() string { return tg.name } -func (tg *TagJunkGenerator) nameIndex() (int, error) { +func (tg *TagJunkPacketGenerator) nameIndex() (int, error) { if len(tg.name) != 2 { return 0, fmt.Errorf("name must be 2 character long: %s", tg.name) } diff --git a/device/awg/tag_junk_generator_test.go b/device/awg/tag_junk_packet_generator_test.go similarity index 82% rename from device/awg/tag_junk_generator_test.go rename to device/awg/tag_junk_packet_generator_test.go index ee4b77e..adc4e86 100644 --- a/device/awg/tag_junk_generator_test.go +++ b/device/awg/tag_junk_packet_generator_test.go @@ -14,13 +14,13 @@ func TestNewTagJunkGenerator(t *testing.T) { name string genName string size int - expected TagJunkGenerator + expected TagJunkPacketGenerator }{ { name: "Create new generator with empty name", genName: "", size: 0, - expected: TagJunkGenerator{ + expected: TagJunkPacketGenerator{ name: "", packetSize: 0, generators: make([]Generator, 0), @@ -30,7 +30,7 @@ func TestNewTagJunkGenerator(t *testing.T) { name: "Create new generator with valid name", genName: "T1", size: 0, - expected: TagJunkGenerator{ + expected: TagJunkPacketGenerator{ name: "T1", packetSize: 0, generators: make([]Generator, 0), @@ -40,7 +40,7 @@ func TestNewTagJunkGenerator(t *testing.T) { name: "Create new generator with non-zero size", genName: "T2", size: 5, - expected: TagJunkGenerator{ + expected: TagJunkPacketGenerator{ name: "T2", packetSize: 0, generators: make([]Generator, 5), @@ -52,7 +52,7 @@ func TestNewTagJunkGenerator(t *testing.T) { tc := tc // capture range variable t.Run(tc.name, func(t *testing.T) { t.Parallel() - result := newTagJunkGenerator(tc.genName, tc.size) + result := newTagJunkPacketGenerator(tc.genName, tc.size) require.Equal(t, tc.expected.name, result.name) require.Equal(t, tc.expected.packetSize, result.packetSize) require.Len(t, result.generators, len(tc.expected.generators)) @@ -65,21 +65,21 @@ func TestTagJunkGeneratorAppend(t *testing.T) { testCases := []struct { name string - initialState TagJunkGenerator + initialState TagJunkPacketGenerator mockSize int expectedLength int expectedSize int }{ { name: "Append to empty generator", - initialState: newTagJunkGenerator("T1", 0), + initialState: newTagJunkPacketGenerator("T1", 0), mockSize: 5, expectedLength: 1, expectedSize: 5, }, { name: "Append to non-empty generator", - initialState: TagJunkGenerator{name: "T2", packetSize: 10, generators: make([]Generator, 2)}, + initialState: TagJunkPacketGenerator{name: "T2", packetSize: 10, generators: make([]Generator, 2)}, mockSize: 7, expectedLength: 3, // 2 existing + 1 new expectedSize: 17, // 10 + 7 @@ -111,20 +111,20 @@ func TestTagJunkGeneratorGenerate(t *testing.T) { testCases := []struct { name string - setupGenerator func() TagJunkGenerator + setupGenerator func() TagJunkPacketGenerator expected []byte }{ { name: "Generate with empty generators", - setupGenerator: func() TagJunkGenerator { - return newTagJunkGenerator("T1", 0) + setupGenerator: func() TagJunkPacketGenerator { + return newTagJunkPacketGenerator("T1", 0) }, expected: []byte{}, }, { name: "Generate with single generator", - setupGenerator: func() TagJunkGenerator { - tg := newTagJunkGenerator("T2", 0) + setupGenerator: func() TagJunkPacketGenerator { + tg := newTagJunkPacketGenerator("T2", 0) tg.append(mockGen1) return tg }, @@ -132,8 +132,8 @@ func TestTagJunkGeneratorGenerate(t *testing.T) { }, { name: "Generate with multiple generators", - setupGenerator: func() TagJunkGenerator { - tg := newTagJunkGenerator("T3", 0) + setupGenerator: func() TagJunkPacketGenerator { + tg := newTagJunkPacketGenerator("T3", 0) tg.append(mockGen1) tg.append(mockGen2) return tg @@ -192,7 +192,7 @@ func TestTagJunkGeneratorNameIndex(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - tg := TagJunkGenerator{name: tc.generatorName} + tg := TagJunkPacketGenerator{name: tc.generatorName} index, err := tg.nameIndex() if tc.expectError { diff --git a/device/awg/tag_junk_packet_generators.go b/device/awg/tag_junk_packet_generators.go new file mode 100644 index 0000000..5f40db2 --- /dev/null +++ b/device/awg/tag_junk_packet_generators.go @@ -0,0 +1,55 @@ +package awg + +import "fmt" + +type TagJunkPacketGenerators struct { + tagGenerators []TagJunkPacketGenerator + length int + DefaultJunkCount int // Jc +} + +func (generators *TagJunkPacketGenerators) AppendGenerator(generator TagJunkPacketGenerator) { + generators.tagGenerators = append(generators.tagGenerators, generator) + generators.length++ +} + +func (generators *TagJunkPacketGenerators) IsDefined() bool { + return len(generators.tagGenerators) > 0 +} + +// validate that packets were defined consecutively +func (generators *TagJunkPacketGenerators) Validate() error { + seen := make([]bool, len(generators.tagGenerators)) + for _, generator := range generators.tagGenerators { + index, err := generator.nameIndex() + if index > len(generators.tagGenerators) { + return fmt.Errorf("junk packet index should be consecutive") + } + if err != nil { + return fmt.Errorf("name index: %w", err) + } else { + seen[index-1] = true + } + } + + for _, found := range seen { + if !found { + return fmt.Errorf("junk packet index should be consecutive") + } + } + + return nil +} + +func (generators *TagJunkPacketGenerators) GeneratePackets() [][]byte { + var rv = make([][]byte, 0, generators.length+generators.DefaultJunkCount) + + for i, tagGenerator := range generators.tagGenerators { + PacketCounter.Inc() + rv = append(rv, make([]byte, tagGenerator.packetSize)) + copy(rv[i], tagGenerator.generatePacket()) + } + PacketCounter.Add(uint64(generators.DefaultJunkCount)) + + return rv +} diff --git a/device/awg/tag_junk_generator_handler_test.go b/device/awg/tag_junk_packet_generators_test.go similarity index 56% rename from device/awg/tag_junk_generator_handler_test.go rename to device/awg/tag_junk_packet_generators_test.go index 3c5efc9..e006426 100644 --- a/device/awg/tag_junk_generator_handler_test.go +++ b/device/awg/tag_junk_packet_generators_test.go @@ -10,11 +10,11 @@ import ( func TestTagJunkGeneratorHandlerAppendGenerator(t *testing.T) { tests := []struct { name string - generator TagJunkGenerator + generator TagJunkPacketGenerator }{ { name: "append single generator", - generator: newTagJunkGenerator("t1", 10), + generator: newTagJunkPacketGenerator("t1", 10), }, } @@ -22,17 +22,17 @@ func TestTagJunkGeneratorHandlerAppendGenerator(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - handler := &TagJunkGeneratorHandler{} + generators := &TagJunkPacketGenerators{} // Initial length should be 0 - require.Equal(t, 0, handler.length) - require.Empty(t, handler.tagGenerators) + require.Equal(t, 0, generators.length) + require.Empty(t, generators.tagGenerators) // After append, length should be 1 and generator should be added - handler.AppendGenerator(tt.generator) - require.Equal(t, 1, handler.length) - require.Len(t, handler.tagGenerators, 1) - require.Equal(t, tt.generator, handler.tagGenerators[0]) + generators.AppendGenerator(tt.generator) + require.Equal(t, 1, generators.length) + require.Len(t, generators.tagGenerators, 1) + require.Equal(t, tt.generator, generators.tagGenerators[0]) }) } } @@ -40,42 +40,42 @@ func TestTagJunkGeneratorHandlerAppendGenerator(t *testing.T) { func TestTagJunkGeneratorHandlerValidate(t *testing.T) { tests := []struct { name string - generators []TagJunkGenerator + generators []TagJunkPacketGenerator wantErr bool errMsg string }{ { name: "bad start", - generators: []TagJunkGenerator{ - newTagJunkGenerator("t3", 10), - newTagJunkGenerator("t4", 10), + generators: []TagJunkPacketGenerator{ + newTagJunkPacketGenerator("t3", 10), + newTagJunkPacketGenerator("t4", 10), }, wantErr: true, errMsg: "junk packet index should be consecutive", }, { name: "non-consecutive indices", - generators: []TagJunkGenerator{ - newTagJunkGenerator("t1", 10), - newTagJunkGenerator("t3", 10), // Missing t2 + generators: []TagJunkPacketGenerator{ + newTagJunkPacketGenerator("t1", 10), + newTagJunkPacketGenerator("t3", 10), // Missing t2 }, wantErr: true, errMsg: "junk packet index should be consecutive", }, { name: "consecutive indices", - generators: []TagJunkGenerator{ - newTagJunkGenerator("t1", 10), - newTagJunkGenerator("t2", 10), - newTagJunkGenerator("t3", 10), - newTagJunkGenerator("t4", 10), - newTagJunkGenerator("t5", 10), + generators: []TagJunkPacketGenerator{ + newTagJunkPacketGenerator("t1", 10), + newTagJunkPacketGenerator("t2", 10), + newTagJunkPacketGenerator("t3", 10), + newTagJunkPacketGenerator("t4", 10), + newTagJunkPacketGenerator("t5", 10), }, }, { name: "nameIndex error", - generators: []TagJunkGenerator{ - newTagJunkGenerator("error", 10), + generators: []TagJunkPacketGenerator{ + newTagJunkPacketGenerator("error", 10), }, wantErr: true, errMsg: "name must be 2 character long", @@ -86,12 +86,12 @@ func TestTagJunkGeneratorHandlerValidate(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - handler := &TagJunkGeneratorHandler{} + generators := &TagJunkPacketGenerators{} for _, gen := range tt.generators { - handler.AppendGenerator(gen) + generators.AppendGenerator(gen) } - err := handler.Validate() + err := generators.Validate() if tt.wantErr { require.Error(t, err) require.Contains(t, err.Error(), tt.errMsg) @@ -110,20 +110,20 @@ func TestTagJunkGeneratorHandlerGenerate(t *testing.T) { tests := []struct { name string - setupGenerator func() []TagJunkGenerator + setupGenerator func() []TagJunkPacketGenerator expected [][]byte }{ { name: "generate with no default junk", - setupGenerator: func() []TagJunkGenerator { - tg1 := newTagJunkGenerator("t1", 0) + setupGenerator: func() []TagJunkPacketGenerator { + tg1 := newTagJunkPacketGenerator("t1", 0) tg1.append(mockGen1) tg1.append(mockGen2) - tg2 := newTagJunkGenerator("t2", 0) + tg2 := newTagJunkPacketGenerator("t2", 0) tg2.append(mockGen2) tg2.append(mockGen1) - return []TagJunkGenerator{tg1, tg2} + return []TagJunkPacketGenerator{tg1, tg2} }, expected: [][]byte{ append(mockByte1, mockByte2...), @@ -136,13 +136,13 @@ func TestTagJunkGeneratorHandlerGenerate(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - handler := &TagJunkGeneratorHandler{} - generators := tt.setupGenerator() - for _, gen := range generators { - handler.AppendGenerator(gen) + generators := &TagJunkPacketGenerators{} + tagGenerators := tt.setupGenerator() + for _, gen := range tagGenerators { + generators.AppendGenerator(gen) } - result := handler.GeneratePackets() + result := generators.GeneratePackets() require.Equal(t, result, tt.expected) }) } diff --git a/device/awg/tag_parser.go b/device/awg/tag_parser.go index 1180ac6..0359e10 100644 --- a/device/awg/tag_parser.go +++ b/device/awg/tag_parser.go @@ -54,10 +54,10 @@ func parseTag(input string) (Tag, error) { } // TODO: pointernes -func Parse(name, input string) (TagJunkGenerator, error) { +func Parse(name, input string) (TagJunkPacketGenerator, error) { inputSlice := strings.Split(input, "<") if len(inputSlice) <= 1 { - return TagJunkGenerator{}, fmt.Errorf("empty input: %s", input) + return TagJunkPacketGenerator{}, fmt.Errorf("empty input: %s", input) } uniqueTagCheck := make(map[EnumTag]bool, len(uniqueTags)) @@ -65,28 +65,28 @@ func Parse(name, input string) (TagJunkGenerator, error) { // skip byproduct of split inputSlice = inputSlice[1:] - rv := newTagJunkGenerator(name, len(inputSlice)) + rv := newTagJunkPacketGenerator(name, len(inputSlice)) for _, inputParam := range inputSlice { if len(inputParam) <= 1 { - return TagJunkGenerator{}, fmt.Errorf("empty tag in input: %s", inputSlice) + return TagJunkPacketGenerator{}, fmt.Errorf("empty tag in input: %s", inputSlice) } else if strings.Count(inputParam, ">") != 1 { - return TagJunkGenerator{}, fmt.Errorf("ill formated input: %s", input) + return TagJunkPacketGenerator{}, fmt.Errorf("ill formated input: %s", input) } tag, _ := parseTag(inputParam) creator, ok := generatorCreator[tag.Name] if !ok { - return TagJunkGenerator{}, fmt.Errorf("invalid tag: %s", tag.Name) + return TagJunkPacketGenerator{}, fmt.Errorf("invalid tag: %s", tag.Name) } if present, ok := uniqueTagCheck[tag.Name]; ok { if present { - return TagJunkGenerator{}, fmt.Errorf("tag %s needs to be unique", tag.Name) + return TagJunkPacketGenerator{}, fmt.Errorf("tag %s needs to be unique", tag.Name) } uniqueTagCheck[tag.Name] = true } generator, err := creator(tag.Param) if err != nil { - return TagJunkGenerator{}, fmt.Errorf("gen: %w", err) + return TagJunkPacketGenerator{}, fmt.Errorf("gen: %w", err) } // TODO: handle counter tag diff --git a/device/device_test.go b/device/device_test.go index a72159a..6ede99e 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -50,48 +50,7 @@ func uapiCfg(cfg ...string) string { // genConfigs generates a pair of configs that connect to each other. // The configs use distinct, probably-usable ports. -func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { - var key1, key2 NoisePrivateKey - _, err := rand.Read(key1[:]) - if err != nil { - tb.Errorf("unable to generate private key random bytes: %v", err) - } - _, err = rand.Read(key2[:]) - if err != nil { - tb.Errorf("unable to generate private key random bytes: %v", err) - } - pub1, pub2 := key1.publicKey(), key2.publicKey() - - cfgs[0] = uapiCfg( - "private_key", hex.EncodeToString(key1[:]), - "listen_port", "0", - "replace_peers", "true", - "public_key", hex.EncodeToString(pub2[:]), - "protocol_version", "1", - "replace_allowed_ips", "true", - "allowed_ip", "1.0.0.2/32", - ) - endpointCfgs[0] = uapiCfg( - "public_key", hex.EncodeToString(pub2[:]), - "endpoint", "127.0.0.1:%d", - ) - cfgs[1] = uapiCfg( - "private_key", hex.EncodeToString(key2[:]), - "listen_port", "0", - "replace_peers", "true", - "public_key", hex.EncodeToString(pub1[:]), - "protocol_version", "1", - "replace_allowed_ips", "true", - "allowed_ip", "1.0.0.1/32", - ) - endpointCfgs[1] = uapiCfg( - "public_key", hex.EncodeToString(pub1[:]), - "endpoint", "127.0.0.1:%d", - ) - return -} - -func genAWGConfigs(tb testing.TB, cfg ...string) (cfgs, endpointCfgs [2]string) { +func genConfigs(tb testing.TB, cfg ...string) (cfgs, endpointCfgs [2]string) { var key1, key2 NoisePrivateKey _, err := rand.Read(key1[:]) if err != nil { @@ -207,11 +166,8 @@ func genTestPair( extraCfg ...string, ) (pair testPair) { var cfg, endpointCfg [2]string - if len(extraCfg) > 0 { - cfg, endpointCfg = genAWGConfigs(tb, extraCfg...) - } else { - cfg, endpointCfg = genConfigs(tb) - } + cfg, endpointCfg = genConfigs(tb, extraCfg...) + var binds [2]conn.Bind if realSocket { binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind() diff --git a/device/peer.go b/device/peer.go index 9c7eaab..fdc0b86 100644 --- a/device/peer.go +++ b/device/peer.go @@ -145,7 +145,7 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error { return err } -func (peer *Peer) SendBuffersCountPacket(buffers [][]byte) error { +func (peer *Peer) SendAndCountBuffers(buffers [][]byte) error { err := peer.SendBuffers(buffers) if err == nil { awg.PacketCounter.Add(uint64(len(buffers))) diff --git a/device/send.go b/device/send.go index 74015c2..e38e126 100644 --- a/device/send.go +++ b/device/send.go @@ -190,7 +190,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { sendBuffer = append(sendBuffer, junkedHeader) - err = peer.SendBuffersCountPacket(sendBuffer) + err = peer.SendAndCountBuffers(sendBuffer) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err) } @@ -246,7 +246,7 @@ func (peer *Peer) SendHandshakeResponse() error { peer.timersAnyAuthenticatedPacketSent() // TODO: allocation could be avoided - err = peer.SendBuffersCountPacket([][]byte{junkedHeader}) + err = peer.SendAndCountBuffers([][]byte{junkedHeader}) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err) } @@ -600,7 +600,7 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err := peer.SendBuffersCountPacket(bufs) + err := peer.SendAndCountBuffers(bufs) if dataSent { peer.timersDataSent() }