From a77df8158d8e101d85c1da90384effd6008ea222 Mon Sep 17 00:00:00 2001 From: Mark Puha Date: Thu, 12 Jun 2025 19:35:33 +0200 Subject: [PATCH] feat: ready for tools implementation --- device/awg/special_handshake_handler.go | 10 ++--- device/awg/tag_generator.go | 9 ++++- device/awg/tag_junk_packet_generator.go | 19 +++++++-- device/awg/tag_junk_packet_generator_test.go | 8 +++- device/awg/tag_junk_packet_generators.go | 15 ++++++- device/awg/tag_parser.go | 15 +++++-- device/device.go | 2 +- device/device_test.go | 21 ++++++++-- device/uapi.go | 42 +++++++++++++++----- 9 files changed, 110 insertions(+), 31 deletions(-) diff --git a/device/awg/special_handshake_handler.go b/device/awg/special_handshake_handler.go index fb83a53..40b575d 100644 --- a/device/awg/special_handshake_handler.go +++ b/device/awg/special_handshake_handler.go @@ -22,8 +22,8 @@ var WaitResponse = struct { type SpecialHandshakeHandler struct { isFirstDone bool - SpecialJunk TagJunkGeneratorHandler - ControlledJunk TagJunkGeneratorHandler + SpecialJunk TagJunkPacketGenerators + ControlledJunk TagJunkPacketGenerators nextItime time.Time ITimeout time.Duration // seconds @@ -46,14 +46,14 @@ func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte { if !handler.SpecialJunk.IsDefined() { return nil } + // TODO: create tests if !handler.isFirstDone { handler.isFirstDone = true handler.nextItime = time.Now().Add(handler.ITimeout) - return nil - } - if !handler.isTimeToSendSpecial() { + return handler.SpecialJunk.GeneratePackets() + } else if !handler.isTimeToSendSpecial() { return nil } diff --git a/device/awg/tag_generator.go b/device/awg/tag_generator.go index 4aaa727..65d8004 100644 --- a/device/awg/tag_generator.go +++ b/device/awg/tag_generator.go @@ -90,7 +90,10 @@ func newRandomPacketGenerator(param string) (Generator, error) { return nil, fmt.Errorf("random packet crand read: %w", err) } - return &RandomPacketGenerator{cha8Rand: v2.NewChaCha8([32]byte(buf)), size: size}, nil + return &RandomPacketGenerator{ + cha8Rand: v2.NewChaCha8([32]byte(buf)), + size: size, + }, nil } type TimestampGenerator struct { @@ -137,7 +140,9 @@ func newWaitTimeoutGenerator(param string) (Generator, error) { return nil, fmt.Errorf("timeout must be less than 5000ms") } - return &WaitTimeoutGenerator{waitTimeout: time.Duration(timeout) * time.Millisecond}, nil + return &WaitTimeoutGenerator{ + waitTimeout: time.Duration(timeout) * time.Millisecond, + }, nil } type PacketCounterGenerator struct { diff --git a/device/awg/tag_junk_packet_generator.go b/device/awg/tag_junk_packet_generator.go index 0de80d3..fdbebc8 100644 --- a/device/awg/tag_junk_packet_generator.go +++ b/device/awg/tag_junk_packet_generator.go @@ -6,13 +6,19 @@ import ( ) type TagJunkPacketGenerator struct { - name string + name string + tagValue string + packetSize int generators []Generator } -func newTagJunkPacketGenerator(name string, size int) TagJunkPacketGenerator { - return TagJunkPacketGenerator{name: name, generators: make([]Generator, 0, size)} +func newTagJunkPacketGenerator(name, tagValue string, size int) TagJunkPacketGenerator { + return TagJunkPacketGenerator{ + name: name, + tagValue: tagValue, + generators: make([]Generator, 0, size), + } } func (tg *TagJunkPacketGenerator) append(generator Generator) { @@ -44,3 +50,10 @@ func (tg *TagJunkPacketGenerator) nameIndex() (int, error) { } return index, nil } + +func (tg *TagJunkPacketGenerator) IpcGetFields() IpcFields { + return IpcFields{ + Key: tg.name, + Value: tg.tagValue, + } +} diff --git a/device/awg/tag_junk_packet_generator_test.go b/device/awg/tag_junk_packet_generator_test.go index adc4e86..d4a724f 100644 --- a/device/awg/tag_junk_packet_generator_test.go +++ b/device/awg/tag_junk_packet_generator_test.go @@ -78,8 +78,12 @@ func TestTagJunkGeneratorAppend(t *testing.T) { expectedSize: 5, }, { - name: "Append to non-empty generator", - initialState: TagJunkPacketGenerator{name: "T2", packetSize: 10, generators: make([]Generator, 2)}, + name: "Append to non-empty generator", + initialState: TagJunkPacketGenerator{ + name: "T2", + packetSize: 10, + generators: make([]Generator, 2), + }, mockSize: 7, expectedLength: 3, // 2 existing + 1 new expectedSize: 17, // 10 + 7 diff --git a/device/awg/tag_junk_packet_generators.go b/device/awg/tag_junk_packet_generators.go index 5f40db2..9921eb0 100644 --- a/device/awg/tag_junk_packet_generators.go +++ b/device/awg/tag_junk_packet_generators.go @@ -8,7 +8,9 @@ type TagJunkPacketGenerators struct { DefaultJunkCount int // Jc } -func (generators *TagJunkPacketGenerators) AppendGenerator(generator TagJunkPacketGenerator) { +func (generators *TagJunkPacketGenerators) AppendGenerator( + generator TagJunkPacketGenerator, +) { generators.tagGenerators = append(generators.tagGenerators, generator) generators.length++ } @@ -45,11 +47,20 @@ func (generators *TagJunkPacketGenerators) GeneratePackets() [][]byte { var rv = make([][]byte, 0, generators.length+generators.DefaultJunkCount) for i, tagGenerator := range generators.tagGenerators { - PacketCounter.Inc() rv = append(rv, make([]byte, tagGenerator.packetSize)) copy(rv[i], tagGenerator.generatePacket()) + PacketCounter.Inc() } PacketCounter.Add(uint64(generators.DefaultJunkCount)) return rv } + +func (tg *TagJunkPacketGenerators) IpcGetFields() []IpcFields { + rv := make([]IpcFields, 0, len(tg.tagGenerators)) + for _, generator := range tg.tagGenerators { + rv = append(rv, generator.IpcGetFields()) + } + + return rv +} diff --git a/device/awg/tag_parser.go b/device/awg/tag_parser.go index 0359e10..2b09226 100644 --- a/device/awg/tag_parser.go +++ b/device/awg/tag_parser.go @@ -7,6 +7,8 @@ import ( "strings" ) +type IpcFields struct{ Key, Value string } + type EnumTag string const ( @@ -53,7 +55,6 @@ func parseTag(input string) (Tag, error) { return tag, nil } -// TODO: pointernes func Parse(name, input string) (TagJunkPacketGenerator, error) { inputSlice := strings.Split(input, "<") if len(inputSlice) <= 1 { @@ -65,10 +66,13 @@ func Parse(name, input string) (TagJunkPacketGenerator, error) { // skip byproduct of split inputSlice = inputSlice[1:] - rv := newTagJunkPacketGenerator(name, len(inputSlice)) + rv := newTagJunkPacketGenerator(name, input, len(inputSlice)) for _, inputParam := range inputSlice { if len(inputParam) <= 1 { - return TagJunkPacketGenerator{}, fmt.Errorf("empty tag in input: %s", inputSlice) + return TagJunkPacketGenerator{}, fmt.Errorf( + "empty tag in input: %s", + inputSlice, + ) } else if strings.Count(inputParam, ">") != 1 { return TagJunkPacketGenerator{}, fmt.Errorf("ill formated input: %s", input) } @@ -80,7 +84,10 @@ func Parse(name, input string) (TagJunkPacketGenerator, error) { } if present, ok := uniqueTagCheck[tag.Name]; ok { if present { - return TagJunkPacketGenerator{}, fmt.Errorf("tag %s needs to be unique", tag.Name) + return TagJunkPacketGenerator{}, fmt.Errorf( + "tag %s needs to be unique", + tag.Name, + ) } uniqueTagCheck[tag.Name] = true } diff --git a/device/device.go b/device/device.go index 7d17ef1..a72b053 100644 --- a/device/device.go +++ b/device/device.go @@ -773,7 +773,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { } if tempAwg.HandshakeHandler.IsSet { - if err := tempAwg.HandshakeHandler.Validate(); tempAwg.HandshakeHandler.IsSet && err != nil { + if err := tempAwg.HandshakeHandler.Validate(); err != nil { errs = append(errs, ipcErrorf( ipc.IpcErrorInvalid, "handshake handler validate: %w", err)) } else { diff --git a/device/device_test.go b/device/device_test.go index 6ede99e..7aec4c2 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -7,19 +7,22 @@ package device import ( "bytes" + "context" "encoding/hex" "fmt" "io" "math/rand" "net/netip" "os" + "os/signal" "runtime" "runtime/pprof" "sync" - "sync/atomic" "testing" "time" + "go.uber.org/atomic" + "github.com/amnezia-vpn/amneziawg-go/conn" "github.com/amnezia-vpn/amneziawg-go/conn/bindtest" "github.com/amnezia-vpn/amneziawg-go/tun" @@ -242,7 +245,19 @@ func TestAWGDevicePing(t *testing.T) { }) } +// Needs to be stopped with Ctrl-C func TestAWGHandshakeDevicePing(t *testing.T) { + t.Skip("This test is intended to be run manually, not as part of the test suite.") + + signalContext, cancel := signal.NotifyContext(context.Background(), os.Interrupt) + defer cancel() + isRunning := atomic.NewBool(true) + go func() { + <-signalContext.Done() + fmt.Println("Waiting to finish") + isRunning.Store(false) + }() + goroutineLeakCheck(t) pair := genTestPair(t, true, "i1", "", @@ -262,13 +277,13 @@ func TestAWGHandshakeDevicePing(t *testing.T) { // "h3", "123123", ) t.Run("ping 1.0.0.1", func(t *testing.T) { - for { + for isRunning.Load() { pair.Send(t, Ping, nil) time.Sleep(2 * time.Second) } }) t.Run("ping 1.0.0.2", func(t *testing.T) { - for { + for isRunning.Load() { pair.Send(t, Pong, nil) time.Sleep(2 * time.Second) } diff --git a/device/uapi.go b/device/uapi.go index 1e95e48..9372228 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -126,10 +126,14 @@ func (device *Device) IpcGetOperation(w io.Writer) error { if device.awg.ASecCfg.TransportPacketMagicHeader != 0 { sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader) } - - // for _, generator := range device.awg.HandshakeHandler.ControlledJunk.AppendGenerator { - // sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader) - // } + specialJunkIpcFields := device.awg.HandshakeHandler.SpecialJunk.IpcGetFields() + for _, field := range specialJunkIpcFields { + sendf("%s=%s", field.Key, field.Value) + } + controlledJunkIpcFields := device.awg.HandshakeHandler.ControlledJunk.IpcGetFields() + for _, field := range controlledJunkIpcFields { + sendf("%s=%s", field.Key, field.Value) + } } for _, peer := range device.peers.keyMap { @@ -283,7 +287,11 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol) case "replace_peers": if value != "true" { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value) + return ipcErrorf( + ipc.IpcErrorInvalid, + "failed to set replace_peers, invalid value: %v", + value, + ) } device.log.Verbosef("UAPI: Removing all peers") device.RemoveAllPeers() @@ -470,7 +478,11 @@ func (device *Device) handlePeerLine( case "update_only": // allow disabling of creation if value != "true" { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value) + return ipcErrorf( + ipc.IpcErrorInvalid, + "failed to set update only, invalid value: %v", + value, + ) } if peer.created && !peer.dummy { device.RemovePeer(peer.handshake.remoteStatic) @@ -516,7 +528,11 @@ func (device *Device) handlePeerLine( secs, err := strconv.ParseUint(value, 10, 16) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err) + return ipcErrorf( + ipc.IpcErrorInvalid, + "failed to set persistent keepalive interval: %w", + err, + ) } old := peer.persistentKeepaliveInterval.Swap(uint32(secs)) @@ -527,7 +543,11 @@ func (device *Device) handlePeerLine( case "replace_allowed_ips": device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer) if value != "true" { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value) + return ipcErrorf( + ipc.IpcErrorInvalid, + "failed to replace allowedips, invalid value: %v", + value, + ) } if peer.dummy { return nil @@ -595,7 +615,11 @@ func (device *Device) IpcHandle(socket net.Conn) { return } if nextByte != '\n' { - err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte) + err = ipcErrorf( + ipc.IpcErrorInvalid, + "trailing character in UAPI get: %q", + nextByte, + ) break } err = device.IpcGetOperation(buffered.Writer)