feat: ready for tools implementation

This commit is contained in:
Mark Puha 2025-06-12 19:35:33 +02:00
parent e997fe1def
commit a77df8158d
9 changed files with 110 additions and 31 deletions

View file

@ -22,8 +22,8 @@ var WaitResponse = struct {
type SpecialHandshakeHandler struct {
isFirstDone bool
SpecialJunk TagJunkGeneratorHandler
ControlledJunk TagJunkGeneratorHandler
SpecialJunk TagJunkPacketGenerators
ControlledJunk TagJunkPacketGenerators
nextItime time.Time
ITimeout time.Duration // seconds
@ -46,14 +46,14 @@ func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte {
if !handler.SpecialJunk.IsDefined() {
return nil
}
// TODO: create tests
if !handler.isFirstDone {
handler.isFirstDone = true
handler.nextItime = time.Now().Add(handler.ITimeout)
return nil
}
if !handler.isTimeToSendSpecial() {
return handler.SpecialJunk.GeneratePackets()
} else if !handler.isTimeToSendSpecial() {
return nil
}

View file

@ -90,7 +90,10 @@ func newRandomPacketGenerator(param string) (Generator, error) {
return nil, fmt.Errorf("random packet crand read: %w", err)
}
return &RandomPacketGenerator{cha8Rand: v2.NewChaCha8([32]byte(buf)), size: size}, nil
return &RandomPacketGenerator{
cha8Rand: v2.NewChaCha8([32]byte(buf)),
size: size,
}, nil
}
type TimestampGenerator struct {
@ -137,7 +140,9 @@ func newWaitTimeoutGenerator(param string) (Generator, error) {
return nil, fmt.Errorf("timeout must be less than 5000ms")
}
return &WaitTimeoutGenerator{waitTimeout: time.Duration(timeout) * time.Millisecond}, nil
return &WaitTimeoutGenerator{
waitTimeout: time.Duration(timeout) * time.Millisecond,
}, nil
}
type PacketCounterGenerator struct {

View file

@ -6,13 +6,19 @@ import (
)
type TagJunkPacketGenerator struct {
name string
name string
tagValue string
packetSize int
generators []Generator
}
func newTagJunkPacketGenerator(name string, size int) TagJunkPacketGenerator {
return TagJunkPacketGenerator{name: name, generators: make([]Generator, 0, size)}
func newTagJunkPacketGenerator(name, tagValue string, size int) TagJunkPacketGenerator {
return TagJunkPacketGenerator{
name: name,
tagValue: tagValue,
generators: make([]Generator, 0, size),
}
}
func (tg *TagJunkPacketGenerator) append(generator Generator) {
@ -44,3 +50,10 @@ func (tg *TagJunkPacketGenerator) nameIndex() (int, error) {
}
return index, nil
}
func (tg *TagJunkPacketGenerator) IpcGetFields() IpcFields {
return IpcFields{
Key: tg.name,
Value: tg.tagValue,
}
}

View file

@ -78,8 +78,12 @@ func TestTagJunkGeneratorAppend(t *testing.T) {
expectedSize: 5,
},
{
name: "Append to non-empty generator",
initialState: TagJunkPacketGenerator{name: "T2", packetSize: 10, generators: make([]Generator, 2)},
name: "Append to non-empty generator",
initialState: TagJunkPacketGenerator{
name: "T2",
packetSize: 10,
generators: make([]Generator, 2),
},
mockSize: 7,
expectedLength: 3, // 2 existing + 1 new
expectedSize: 17, // 10 + 7

View file

@ -8,7 +8,9 @@ type TagJunkPacketGenerators struct {
DefaultJunkCount int // Jc
}
func (generators *TagJunkPacketGenerators) AppendGenerator(generator TagJunkPacketGenerator) {
func (generators *TagJunkPacketGenerators) AppendGenerator(
generator TagJunkPacketGenerator,
) {
generators.tagGenerators = append(generators.tagGenerators, generator)
generators.length++
}
@ -45,11 +47,20 @@ func (generators *TagJunkPacketGenerators) GeneratePackets() [][]byte {
var rv = make([][]byte, 0, generators.length+generators.DefaultJunkCount)
for i, tagGenerator := range generators.tagGenerators {
PacketCounter.Inc()
rv = append(rv, make([]byte, tagGenerator.packetSize))
copy(rv[i], tagGenerator.generatePacket())
PacketCounter.Inc()
}
PacketCounter.Add(uint64(generators.DefaultJunkCount))
return rv
}
func (tg *TagJunkPacketGenerators) IpcGetFields() []IpcFields {
rv := make([]IpcFields, 0, len(tg.tagGenerators))
for _, generator := range tg.tagGenerators {
rv = append(rv, generator.IpcGetFields())
}
return rv
}

View file

@ -7,6 +7,8 @@ import (
"strings"
)
type IpcFields struct{ Key, Value string }
type EnumTag string
const (
@ -53,7 +55,6 @@ func parseTag(input string) (Tag, error) {
return tag, nil
}
// TODO: pointernes
func Parse(name, input string) (TagJunkPacketGenerator, error) {
inputSlice := strings.Split(input, "<")
if len(inputSlice) <= 1 {
@ -65,10 +66,13 @@ func Parse(name, input string) (TagJunkPacketGenerator, error) {
// skip byproduct of split
inputSlice = inputSlice[1:]
rv := newTagJunkPacketGenerator(name, len(inputSlice))
rv := newTagJunkPacketGenerator(name, input, len(inputSlice))
for _, inputParam := range inputSlice {
if len(inputParam) <= 1 {
return TagJunkPacketGenerator{}, fmt.Errorf("empty tag in input: %s", inputSlice)
return TagJunkPacketGenerator{}, fmt.Errorf(
"empty tag in input: %s",
inputSlice,
)
} else if strings.Count(inputParam, ">") != 1 {
return TagJunkPacketGenerator{}, fmt.Errorf("ill formated input: %s", input)
}
@ -80,7 +84,10 @@ func Parse(name, input string) (TagJunkPacketGenerator, error) {
}
if present, ok := uniqueTagCheck[tag.Name]; ok {
if present {
return TagJunkPacketGenerator{}, fmt.Errorf("tag %s needs to be unique", tag.Name)
return TagJunkPacketGenerator{}, fmt.Errorf(
"tag %s needs to be unique",
tag.Name,
)
}
uniqueTagCheck[tag.Name] = true
}

View file

@ -773,7 +773,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
}
if tempAwg.HandshakeHandler.IsSet {
if err := tempAwg.HandshakeHandler.Validate(); tempAwg.HandshakeHandler.IsSet && err != nil {
if err := tempAwg.HandshakeHandler.Validate(); err != nil {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid, "handshake handler validate: %w", err))
} else {

View file

@ -7,19 +7,22 @@ package device
import (
"bytes"
"context"
"encoding/hex"
"fmt"
"io"
"math/rand"
"net/netip"
"os"
"os/signal"
"runtime"
"runtime/pprof"
"sync"
"sync/atomic"
"testing"
"time"
"go.uber.org/atomic"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
"github.com/amnezia-vpn/amneziawg-go/tun"
@ -242,7 +245,19 @@ func TestAWGDevicePing(t *testing.T) {
})
}
// Needs to be stopped with Ctrl-C
func TestAWGHandshakeDevicePing(t *testing.T) {
t.Skip("This test is intended to be run manually, not as part of the test suite.")
signalContext, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
defer cancel()
isRunning := atomic.NewBool(true)
go func() {
<-signalContext.Done()
fmt.Println("Waiting to finish")
isRunning.Store(false)
}()
goroutineLeakCheck(t)
pair := genTestPair(t, true,
"i1", "<b 0xf6ab3267fa><c><b 0xf6ab><t><r 10><wt 10>",
@ -262,13 +277,13 @@ func TestAWGHandshakeDevicePing(t *testing.T) {
// "h3", "123123",
)
t.Run("ping 1.0.0.1", func(t *testing.T) {
for {
for isRunning.Load() {
pair.Send(t, Ping, nil)
time.Sleep(2 * time.Second)
}
})
t.Run("ping 1.0.0.2", func(t *testing.T) {
for {
for isRunning.Load() {
pair.Send(t, Pong, nil)
time.Sleep(2 * time.Second)
}

View file

@ -126,10 +126,14 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
if device.awg.ASecCfg.TransportPacketMagicHeader != 0 {
sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader)
}
// for _, generator := range device.awg.HandshakeHandler.ControlledJunk.AppendGenerator {
// sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader)
// }
specialJunkIpcFields := device.awg.HandshakeHandler.SpecialJunk.IpcGetFields()
for _, field := range specialJunkIpcFields {
sendf("%s=%s", field.Key, field.Value)
}
controlledJunkIpcFields := device.awg.HandshakeHandler.ControlledJunk.IpcGetFields()
for _, field := range controlledJunkIpcFields {
sendf("%s=%s", field.Key, field.Value)
}
}
for _, peer := range device.peers.keyMap {
@ -283,7 +287,11 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
case "replace_peers":
if value != "true" {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value)
return ipcErrorf(
ipc.IpcErrorInvalid,
"failed to set replace_peers, invalid value: %v",
value,
)
}
device.log.Verbosef("UAPI: Removing all peers")
device.RemoveAllPeers()
@ -470,7 +478,11 @@ func (device *Device) handlePeerLine(
case "update_only":
// allow disabling of creation
if value != "true" {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value)
return ipcErrorf(
ipc.IpcErrorInvalid,
"failed to set update only, invalid value: %v",
value,
)
}
if peer.created && !peer.dummy {
device.RemovePeer(peer.handshake.remoteStatic)
@ -516,7 +528,11 @@ func (device *Device) handlePeerLine(
secs, err := strconv.ParseUint(value, 10, 16)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
return ipcErrorf(
ipc.IpcErrorInvalid,
"failed to set persistent keepalive interval: %w",
err,
)
}
old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
@ -527,7 +543,11 @@ func (device *Device) handlePeerLine(
case "replace_allowed_ips":
device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
if value != "true" {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value)
return ipcErrorf(
ipc.IpcErrorInvalid,
"failed to replace allowedips, invalid value: %v",
value,
)
}
if peer.dummy {
return nil
@ -595,7 +615,11 @@ func (device *Device) IpcHandle(socket net.Conn) {
return
}
if nextByte != '\n' {
err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte)
err = ipcErrorf(
ipc.IpcErrorInvalid,
"trailing character in UAPI get: %q",
nextByte,
)
break
}
err = device.IpcGetOperation(buffered.Writer)