From a77df8158d8e101d85c1da90384effd6008ea222 Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Thu, 12 Jun 2025 19:35:33 +0200
Subject: [PATCH] feat: ready for tools implementation
---
device/awg/special_handshake_handler.go | 10 ++---
device/awg/tag_generator.go | 9 ++++-
device/awg/tag_junk_packet_generator.go | 19 +++++++--
device/awg/tag_junk_packet_generator_test.go | 8 +++-
device/awg/tag_junk_packet_generators.go | 15 ++++++-
device/awg/tag_parser.go | 15 +++++--
device/device.go | 2 +-
device/device_test.go | 21 ++++++++--
device/uapi.go | 42 +++++++++++++++-----
9 files changed, 110 insertions(+), 31 deletions(-)
diff --git a/device/awg/special_handshake_handler.go b/device/awg/special_handshake_handler.go
index fb83a53..40b575d 100644
--- a/device/awg/special_handshake_handler.go
+++ b/device/awg/special_handshake_handler.go
@@ -22,8 +22,8 @@ var WaitResponse = struct {
type SpecialHandshakeHandler struct {
isFirstDone bool
- SpecialJunk TagJunkGeneratorHandler
- ControlledJunk TagJunkGeneratorHandler
+ SpecialJunk TagJunkPacketGenerators
+ ControlledJunk TagJunkPacketGenerators
nextItime time.Time
ITimeout time.Duration // seconds
@@ -46,14 +46,14 @@ func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte {
if !handler.SpecialJunk.IsDefined() {
return nil
}
+
// TODO: create tests
if !handler.isFirstDone {
handler.isFirstDone = true
handler.nextItime = time.Now().Add(handler.ITimeout)
- return nil
- }
- if !handler.isTimeToSendSpecial() {
+ return handler.SpecialJunk.GeneratePackets()
+ } else if !handler.isTimeToSendSpecial() {
return nil
}
diff --git a/device/awg/tag_generator.go b/device/awg/tag_generator.go
index 4aaa727..65d8004 100644
--- a/device/awg/tag_generator.go
+++ b/device/awg/tag_generator.go
@@ -90,7 +90,10 @@ func newRandomPacketGenerator(param string) (Generator, error) {
return nil, fmt.Errorf("random packet crand read: %w", err)
}
- return &RandomPacketGenerator{cha8Rand: v2.NewChaCha8([32]byte(buf)), size: size}, nil
+ return &RandomPacketGenerator{
+ cha8Rand: v2.NewChaCha8([32]byte(buf)),
+ size: size,
+ }, nil
}
type TimestampGenerator struct {
@@ -137,7 +140,9 @@ func newWaitTimeoutGenerator(param string) (Generator, error) {
return nil, fmt.Errorf("timeout must be less than 5000ms")
}
- return &WaitTimeoutGenerator{waitTimeout: time.Duration(timeout) * time.Millisecond}, nil
+ return &WaitTimeoutGenerator{
+ waitTimeout: time.Duration(timeout) * time.Millisecond,
+ }, nil
}
type PacketCounterGenerator struct {
diff --git a/device/awg/tag_junk_packet_generator.go b/device/awg/tag_junk_packet_generator.go
index 0de80d3..fdbebc8 100644
--- a/device/awg/tag_junk_packet_generator.go
+++ b/device/awg/tag_junk_packet_generator.go
@@ -6,13 +6,19 @@ import (
)
type TagJunkPacketGenerator struct {
- name string
+ name string
+ tagValue string
+
packetSize int
generators []Generator
}
-func newTagJunkPacketGenerator(name string, size int) TagJunkPacketGenerator {
- return TagJunkPacketGenerator{name: name, generators: make([]Generator, 0, size)}
+func newTagJunkPacketGenerator(name, tagValue string, size int) TagJunkPacketGenerator {
+ return TagJunkPacketGenerator{
+ name: name,
+ tagValue: tagValue,
+ generators: make([]Generator, 0, size),
+ }
}
func (tg *TagJunkPacketGenerator) append(generator Generator) {
@@ -44,3 +50,10 @@ func (tg *TagJunkPacketGenerator) nameIndex() (int, error) {
}
return index, nil
}
+
+func (tg *TagJunkPacketGenerator) IpcGetFields() IpcFields {
+ return IpcFields{
+ Key: tg.name,
+ Value: tg.tagValue,
+ }
+}
diff --git a/device/awg/tag_junk_packet_generator_test.go b/device/awg/tag_junk_packet_generator_test.go
index adc4e86..d4a724f 100644
--- a/device/awg/tag_junk_packet_generator_test.go
+++ b/device/awg/tag_junk_packet_generator_test.go
@@ -78,8 +78,12 @@ func TestTagJunkGeneratorAppend(t *testing.T) {
expectedSize: 5,
},
{
- name: "Append to non-empty generator",
- initialState: TagJunkPacketGenerator{name: "T2", packetSize: 10, generators: make([]Generator, 2)},
+ name: "Append to non-empty generator",
+ initialState: TagJunkPacketGenerator{
+ name: "T2",
+ packetSize: 10,
+ generators: make([]Generator, 2),
+ },
mockSize: 7,
expectedLength: 3, // 2 existing + 1 new
expectedSize: 17, // 10 + 7
diff --git a/device/awg/tag_junk_packet_generators.go b/device/awg/tag_junk_packet_generators.go
index 5f40db2..9921eb0 100644
--- a/device/awg/tag_junk_packet_generators.go
+++ b/device/awg/tag_junk_packet_generators.go
@@ -8,7 +8,9 @@ type TagJunkPacketGenerators struct {
DefaultJunkCount int // Jc
}
-func (generators *TagJunkPacketGenerators) AppendGenerator(generator TagJunkPacketGenerator) {
+func (generators *TagJunkPacketGenerators) AppendGenerator(
+ generator TagJunkPacketGenerator,
+) {
generators.tagGenerators = append(generators.tagGenerators, generator)
generators.length++
}
@@ -45,11 +47,20 @@ func (generators *TagJunkPacketGenerators) GeneratePackets() [][]byte {
var rv = make([][]byte, 0, generators.length+generators.DefaultJunkCount)
for i, tagGenerator := range generators.tagGenerators {
- PacketCounter.Inc()
rv = append(rv, make([]byte, tagGenerator.packetSize))
copy(rv[i], tagGenerator.generatePacket())
+ PacketCounter.Inc()
}
PacketCounter.Add(uint64(generators.DefaultJunkCount))
return rv
}
+
+func (tg *TagJunkPacketGenerators) IpcGetFields() []IpcFields {
+ rv := make([]IpcFields, 0, len(tg.tagGenerators))
+ for _, generator := range tg.tagGenerators {
+ rv = append(rv, generator.IpcGetFields())
+ }
+
+ return rv
+}
diff --git a/device/awg/tag_parser.go b/device/awg/tag_parser.go
index 0359e10..2b09226 100644
--- a/device/awg/tag_parser.go
+++ b/device/awg/tag_parser.go
@@ -7,6 +7,8 @@ import (
"strings"
)
+type IpcFields struct{ Key, Value string }
+
type EnumTag string
const (
@@ -53,7 +55,6 @@ func parseTag(input string) (Tag, error) {
return tag, nil
}
-// TODO: pointernes
func Parse(name, input string) (TagJunkPacketGenerator, error) {
inputSlice := strings.Split(input, "<")
if len(inputSlice) <= 1 {
@@ -65,10 +66,13 @@ func Parse(name, input string) (TagJunkPacketGenerator, error) {
// skip byproduct of split
inputSlice = inputSlice[1:]
- rv := newTagJunkPacketGenerator(name, len(inputSlice))
+ rv := newTagJunkPacketGenerator(name, input, len(inputSlice))
for _, inputParam := range inputSlice {
if len(inputParam) <= 1 {
- return TagJunkPacketGenerator{}, fmt.Errorf("empty tag in input: %s", inputSlice)
+ return TagJunkPacketGenerator{}, fmt.Errorf(
+ "empty tag in input: %s",
+ inputSlice,
+ )
} else if strings.Count(inputParam, ">") != 1 {
return TagJunkPacketGenerator{}, fmt.Errorf("ill formated input: %s", input)
}
@@ -80,7 +84,10 @@ func Parse(name, input string) (TagJunkPacketGenerator, error) {
}
if present, ok := uniqueTagCheck[tag.Name]; ok {
if present {
- return TagJunkPacketGenerator{}, fmt.Errorf("tag %s needs to be unique", tag.Name)
+ return TagJunkPacketGenerator{}, fmt.Errorf(
+ "tag %s needs to be unique",
+ tag.Name,
+ )
}
uniqueTagCheck[tag.Name] = true
}
diff --git a/device/device.go b/device/device.go
index 7d17ef1..a72b053 100644
--- a/device/device.go
+++ b/device/device.go
@@ -773,7 +773,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
}
if tempAwg.HandshakeHandler.IsSet {
- if err := tempAwg.HandshakeHandler.Validate(); tempAwg.HandshakeHandler.IsSet && err != nil {
+ if err := tempAwg.HandshakeHandler.Validate(); err != nil {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid, "handshake handler validate: %w", err))
} else {
diff --git a/device/device_test.go b/device/device_test.go
index 6ede99e..7aec4c2 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -7,19 +7,22 @@ package device
import (
"bytes"
+ "context"
"encoding/hex"
"fmt"
"io"
"math/rand"
"net/netip"
"os"
+ "os/signal"
"runtime"
"runtime/pprof"
"sync"
- "sync/atomic"
"testing"
"time"
+ "go.uber.org/atomic"
+
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
"github.com/amnezia-vpn/amneziawg-go/tun"
@@ -242,7 +245,19 @@ func TestAWGDevicePing(t *testing.T) {
})
}
+// Needs to be stopped with Ctrl-C
func TestAWGHandshakeDevicePing(t *testing.T) {
+ t.Skip("This test is intended to be run manually, not as part of the test suite.")
+
+ signalContext, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
+ defer cancel()
+ isRunning := atomic.NewBool(true)
+ go func() {
+ <-signalContext.Done()
+ fmt.Println("Waiting to finish")
+ isRunning.Store(false)
+ }()
+
goroutineLeakCheck(t)
pair := genTestPair(t, true,
"i1", "",
@@ -262,13 +277,13 @@ func TestAWGHandshakeDevicePing(t *testing.T) {
// "h3", "123123",
)
t.Run("ping 1.0.0.1", func(t *testing.T) {
- for {
+ for isRunning.Load() {
pair.Send(t, Ping, nil)
time.Sleep(2 * time.Second)
}
})
t.Run("ping 1.0.0.2", func(t *testing.T) {
- for {
+ for isRunning.Load() {
pair.Send(t, Pong, nil)
time.Sleep(2 * time.Second)
}
diff --git a/device/uapi.go b/device/uapi.go
index 1e95e48..9372228 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -126,10 +126,14 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
if device.awg.ASecCfg.TransportPacketMagicHeader != 0 {
sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader)
}
-
- // for _, generator := range device.awg.HandshakeHandler.ControlledJunk.AppendGenerator {
- // sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader)
- // }
+ specialJunkIpcFields := device.awg.HandshakeHandler.SpecialJunk.IpcGetFields()
+ for _, field := range specialJunkIpcFields {
+ sendf("%s=%s", field.Key, field.Value)
+ }
+ controlledJunkIpcFields := device.awg.HandshakeHandler.ControlledJunk.IpcGetFields()
+ for _, field := range controlledJunkIpcFields {
+ sendf("%s=%s", field.Key, field.Value)
+ }
}
for _, peer := range device.peers.keyMap {
@@ -283,7 +287,11 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
case "replace_peers":
if value != "true" {
- return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value)
+ return ipcErrorf(
+ ipc.IpcErrorInvalid,
+ "failed to set replace_peers, invalid value: %v",
+ value,
+ )
}
device.log.Verbosef("UAPI: Removing all peers")
device.RemoveAllPeers()
@@ -470,7 +478,11 @@ func (device *Device) handlePeerLine(
case "update_only":
// allow disabling of creation
if value != "true" {
- return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value)
+ return ipcErrorf(
+ ipc.IpcErrorInvalid,
+ "failed to set update only, invalid value: %v",
+ value,
+ )
}
if peer.created && !peer.dummy {
device.RemovePeer(peer.handshake.remoteStatic)
@@ -516,7 +528,11 @@ func (device *Device) handlePeerLine(
secs, err := strconv.ParseUint(value, 10, 16)
if err != nil {
- return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
+ return ipcErrorf(
+ ipc.IpcErrorInvalid,
+ "failed to set persistent keepalive interval: %w",
+ err,
+ )
}
old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
@@ -527,7 +543,11 @@ func (device *Device) handlePeerLine(
case "replace_allowed_ips":
device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
if value != "true" {
- return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value)
+ return ipcErrorf(
+ ipc.IpcErrorInvalid,
+ "failed to replace allowedips, invalid value: %v",
+ value,
+ )
}
if peer.dummy {
return nil
@@ -595,7 +615,11 @@ func (device *Device) IpcHandle(socket net.Conn) {
return
}
if nextByte != '\n' {
- err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte)
+ err = ipcErrorf(
+ ipc.IpcErrorInvalid,
+ "trailing character in UAPI get: %q",
+ nextByte,
+ )
break
}
err = device.IpcGetOperation(buffered.Writer)