feat: ready for tools implementation

This commit is contained in:
Mark Puha 2025-06-12 19:35:33 +02:00
parent e997fe1def
commit a77df8158d
9 changed files with 110 additions and 31 deletions

View file

@ -22,8 +22,8 @@ var WaitResponse = struct {
type SpecialHandshakeHandler struct { type SpecialHandshakeHandler struct {
isFirstDone bool isFirstDone bool
SpecialJunk TagJunkGeneratorHandler SpecialJunk TagJunkPacketGenerators
ControlledJunk TagJunkGeneratorHandler ControlledJunk TagJunkPacketGenerators
nextItime time.Time nextItime time.Time
ITimeout time.Duration // seconds ITimeout time.Duration // seconds
@ -46,14 +46,14 @@ func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte {
if !handler.SpecialJunk.IsDefined() { if !handler.SpecialJunk.IsDefined() {
return nil return nil
} }
// TODO: create tests // TODO: create tests
if !handler.isFirstDone { if !handler.isFirstDone {
handler.isFirstDone = true handler.isFirstDone = true
handler.nextItime = time.Now().Add(handler.ITimeout) handler.nextItime = time.Now().Add(handler.ITimeout)
return nil
}
if !handler.isTimeToSendSpecial() { return handler.SpecialJunk.GeneratePackets()
} else if !handler.isTimeToSendSpecial() {
return nil return nil
} }

View file

@ -90,7 +90,10 @@ func newRandomPacketGenerator(param string) (Generator, error) {
return nil, fmt.Errorf("random packet crand read: %w", err) return nil, fmt.Errorf("random packet crand read: %w", err)
} }
return &RandomPacketGenerator{cha8Rand: v2.NewChaCha8([32]byte(buf)), size: size}, nil return &RandomPacketGenerator{
cha8Rand: v2.NewChaCha8([32]byte(buf)),
size: size,
}, nil
} }
type TimestampGenerator struct { type TimestampGenerator struct {
@ -137,7 +140,9 @@ func newWaitTimeoutGenerator(param string) (Generator, error) {
return nil, fmt.Errorf("timeout must be less than 5000ms") return nil, fmt.Errorf("timeout must be less than 5000ms")
} }
return &WaitTimeoutGenerator{waitTimeout: time.Duration(timeout) * time.Millisecond}, nil return &WaitTimeoutGenerator{
waitTimeout: time.Duration(timeout) * time.Millisecond,
}, nil
} }
type PacketCounterGenerator struct { type PacketCounterGenerator struct {

View file

@ -7,12 +7,18 @@ import (
type TagJunkPacketGenerator struct { type TagJunkPacketGenerator struct {
name string name string
tagValue string
packetSize int packetSize int
generators []Generator generators []Generator
} }
func newTagJunkPacketGenerator(name string, size int) TagJunkPacketGenerator { func newTagJunkPacketGenerator(name, tagValue string, size int) TagJunkPacketGenerator {
return TagJunkPacketGenerator{name: name, generators: make([]Generator, 0, size)} return TagJunkPacketGenerator{
name: name,
tagValue: tagValue,
generators: make([]Generator, 0, size),
}
} }
func (tg *TagJunkPacketGenerator) append(generator Generator) { func (tg *TagJunkPacketGenerator) append(generator Generator) {
@ -44,3 +50,10 @@ func (tg *TagJunkPacketGenerator) nameIndex() (int, error) {
} }
return index, nil return index, nil
} }
func (tg *TagJunkPacketGenerator) IpcGetFields() IpcFields {
return IpcFields{
Key: tg.name,
Value: tg.tagValue,
}
}

View file

@ -79,7 +79,11 @@ func TestTagJunkGeneratorAppend(t *testing.T) {
}, },
{ {
name: "Append to non-empty generator", name: "Append to non-empty generator",
initialState: TagJunkPacketGenerator{name: "T2", packetSize: 10, generators: make([]Generator, 2)}, initialState: TagJunkPacketGenerator{
name: "T2",
packetSize: 10,
generators: make([]Generator, 2),
},
mockSize: 7, mockSize: 7,
expectedLength: 3, // 2 existing + 1 new expectedLength: 3, // 2 existing + 1 new
expectedSize: 17, // 10 + 7 expectedSize: 17, // 10 + 7

View file

@ -8,7 +8,9 @@ type TagJunkPacketGenerators struct {
DefaultJunkCount int // Jc DefaultJunkCount int // Jc
} }
func (generators *TagJunkPacketGenerators) AppendGenerator(generator TagJunkPacketGenerator) { func (generators *TagJunkPacketGenerators) AppendGenerator(
generator TagJunkPacketGenerator,
) {
generators.tagGenerators = append(generators.tagGenerators, generator) generators.tagGenerators = append(generators.tagGenerators, generator)
generators.length++ generators.length++
} }
@ -45,11 +47,20 @@ func (generators *TagJunkPacketGenerators) GeneratePackets() [][]byte {
var rv = make([][]byte, 0, generators.length+generators.DefaultJunkCount) var rv = make([][]byte, 0, generators.length+generators.DefaultJunkCount)
for i, tagGenerator := range generators.tagGenerators { for i, tagGenerator := range generators.tagGenerators {
PacketCounter.Inc()
rv = append(rv, make([]byte, tagGenerator.packetSize)) rv = append(rv, make([]byte, tagGenerator.packetSize))
copy(rv[i], tagGenerator.generatePacket()) copy(rv[i], tagGenerator.generatePacket())
PacketCounter.Inc()
} }
PacketCounter.Add(uint64(generators.DefaultJunkCount)) PacketCounter.Add(uint64(generators.DefaultJunkCount))
return rv return rv
} }
func (tg *TagJunkPacketGenerators) IpcGetFields() []IpcFields {
rv := make([]IpcFields, 0, len(tg.tagGenerators))
for _, generator := range tg.tagGenerators {
rv = append(rv, generator.IpcGetFields())
}
return rv
}

View file

@ -7,6 +7,8 @@ import (
"strings" "strings"
) )
type IpcFields struct{ Key, Value string }
type EnumTag string type EnumTag string
const ( const (
@ -53,7 +55,6 @@ func parseTag(input string) (Tag, error) {
return tag, nil return tag, nil
} }
// TODO: pointernes
func Parse(name, input string) (TagJunkPacketGenerator, error) { func Parse(name, input string) (TagJunkPacketGenerator, error) {
inputSlice := strings.Split(input, "<") inputSlice := strings.Split(input, "<")
if len(inputSlice) <= 1 { if len(inputSlice) <= 1 {
@ -65,10 +66,13 @@ func Parse(name, input string) (TagJunkPacketGenerator, error) {
// skip byproduct of split // skip byproduct of split
inputSlice = inputSlice[1:] inputSlice = inputSlice[1:]
rv := newTagJunkPacketGenerator(name, len(inputSlice)) rv := newTagJunkPacketGenerator(name, input, len(inputSlice))
for _, inputParam := range inputSlice { for _, inputParam := range inputSlice {
if len(inputParam) <= 1 { if len(inputParam) <= 1 {
return TagJunkPacketGenerator{}, fmt.Errorf("empty tag in input: %s", inputSlice) return TagJunkPacketGenerator{}, fmt.Errorf(
"empty tag in input: %s",
inputSlice,
)
} else if strings.Count(inputParam, ">") != 1 { } else if strings.Count(inputParam, ">") != 1 {
return TagJunkPacketGenerator{}, fmt.Errorf("ill formated input: %s", input) return TagJunkPacketGenerator{}, fmt.Errorf("ill formated input: %s", input)
} }
@ -80,7 +84,10 @@ func Parse(name, input string) (TagJunkPacketGenerator, error) {
} }
if present, ok := uniqueTagCheck[tag.Name]; ok { if present, ok := uniqueTagCheck[tag.Name]; ok {
if present { if present {
return TagJunkPacketGenerator{}, fmt.Errorf("tag %s needs to be unique", tag.Name) return TagJunkPacketGenerator{}, fmt.Errorf(
"tag %s needs to be unique",
tag.Name,
)
} }
uniqueTagCheck[tag.Name] = true uniqueTagCheck[tag.Name] = true
} }

View file

@ -773,7 +773,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
} }
if tempAwg.HandshakeHandler.IsSet { if tempAwg.HandshakeHandler.IsSet {
if err := tempAwg.HandshakeHandler.Validate(); tempAwg.HandshakeHandler.IsSet && err != nil { if err := tempAwg.HandshakeHandler.Validate(); err != nil {
errs = append(errs, ipcErrorf( errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid, "handshake handler validate: %w", err)) ipc.IpcErrorInvalid, "handshake handler validate: %w", err))
} else { } else {

View file

@ -7,19 +7,22 @@ package device
import ( import (
"bytes" "bytes"
"context"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"io" "io"
"math/rand" "math/rand"
"net/netip" "net/netip"
"os" "os"
"os/signal"
"runtime" "runtime"
"runtime/pprof" "runtime/pprof"
"sync" "sync"
"sync/atomic"
"testing" "testing"
"time" "time"
"go.uber.org/atomic"
"github.com/amnezia-vpn/amneziawg-go/conn" "github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/conn/bindtest" "github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
"github.com/amnezia-vpn/amneziawg-go/tun" "github.com/amnezia-vpn/amneziawg-go/tun"
@ -242,7 +245,19 @@ func TestAWGDevicePing(t *testing.T) {
}) })
} }
// Needs to be stopped with Ctrl-C
func TestAWGHandshakeDevicePing(t *testing.T) { func TestAWGHandshakeDevicePing(t *testing.T) {
t.Skip("This test is intended to be run manually, not as part of the test suite.")
signalContext, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
defer cancel()
isRunning := atomic.NewBool(true)
go func() {
<-signalContext.Done()
fmt.Println("Waiting to finish")
isRunning.Store(false)
}()
goroutineLeakCheck(t) goroutineLeakCheck(t)
pair := genTestPair(t, true, pair := genTestPair(t, true,
"i1", "<b 0xf6ab3267fa><c><b 0xf6ab><t><r 10><wt 10>", "i1", "<b 0xf6ab3267fa><c><b 0xf6ab><t><r 10><wt 10>",
@ -262,13 +277,13 @@ func TestAWGHandshakeDevicePing(t *testing.T) {
// "h3", "123123", // "h3", "123123",
) )
t.Run("ping 1.0.0.1", func(t *testing.T) { t.Run("ping 1.0.0.1", func(t *testing.T) {
for { for isRunning.Load() {
pair.Send(t, Ping, nil) pair.Send(t, Ping, nil)
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
} }
}) })
t.Run("ping 1.0.0.2", func(t *testing.T) { t.Run("ping 1.0.0.2", func(t *testing.T) {
for { for isRunning.Load() {
pair.Send(t, Pong, nil) pair.Send(t, Pong, nil)
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
} }

View file

@ -126,10 +126,14 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
if device.awg.ASecCfg.TransportPacketMagicHeader != 0 { if device.awg.ASecCfg.TransportPacketMagicHeader != 0 {
sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader) sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader)
} }
specialJunkIpcFields := device.awg.HandshakeHandler.SpecialJunk.IpcGetFields()
// for _, generator := range device.awg.HandshakeHandler.ControlledJunk.AppendGenerator { for _, field := range specialJunkIpcFields {
// sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader) sendf("%s=%s", field.Key, field.Value)
// } }
controlledJunkIpcFields := device.awg.HandshakeHandler.ControlledJunk.IpcGetFields()
for _, field := range controlledJunkIpcFields {
sendf("%s=%s", field.Key, field.Value)
}
} }
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
@ -283,7 +287,11 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
case "replace_peers": case "replace_peers":
if value != "true" { if value != "true" {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value) return ipcErrorf(
ipc.IpcErrorInvalid,
"failed to set replace_peers, invalid value: %v",
value,
)
} }
device.log.Verbosef("UAPI: Removing all peers") device.log.Verbosef("UAPI: Removing all peers")
device.RemoveAllPeers() device.RemoveAllPeers()
@ -470,7 +478,11 @@ func (device *Device) handlePeerLine(
case "update_only": case "update_only":
// allow disabling of creation // allow disabling of creation
if value != "true" { if value != "true" {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value) return ipcErrorf(
ipc.IpcErrorInvalid,
"failed to set update only, invalid value: %v",
value,
)
} }
if peer.created && !peer.dummy { if peer.created && !peer.dummy {
device.RemovePeer(peer.handshake.remoteStatic) device.RemovePeer(peer.handshake.remoteStatic)
@ -516,7 +528,11 @@ func (device *Device) handlePeerLine(
secs, err := strconv.ParseUint(value, 10, 16) secs, err := strconv.ParseUint(value, 10, 16)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err) return ipcErrorf(
ipc.IpcErrorInvalid,
"failed to set persistent keepalive interval: %w",
err,
)
} }
old := peer.persistentKeepaliveInterval.Swap(uint32(secs)) old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
@ -527,7 +543,11 @@ func (device *Device) handlePeerLine(
case "replace_allowed_ips": case "replace_allowed_ips":
device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer) device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
if value != "true" { if value != "true" {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value) return ipcErrorf(
ipc.IpcErrorInvalid,
"failed to replace allowedips, invalid value: %v",
value,
)
} }
if peer.dummy { if peer.dummy {
return nil return nil
@ -595,7 +615,11 @@ func (device *Device) IpcHandle(socket net.Conn) {
return return
} }
if nextByte != '\n' { if nextByte != '\n' {
err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte) err = ipcErrorf(
ipc.IpcErrorInvalid,
"trailing character in UAPI get: %q",
nextByte,
)
break break
} }
err = device.IpcGetOperation(buffered.Writer) err = device.IpcGetOperation(buffered.Writer)