mirror of
https://github.com/amnezia-vpn/amneziawg-go.git
synced 2025-08-03 02:02:50 +02:00
feat: ready for tools implementation
This commit is contained in:
parent
e997fe1def
commit
a77df8158d
9 changed files with 110 additions and 31 deletions
|
@ -22,8 +22,8 @@ var WaitResponse = struct {
|
||||||
|
|
||||||
type SpecialHandshakeHandler struct {
|
type SpecialHandshakeHandler struct {
|
||||||
isFirstDone bool
|
isFirstDone bool
|
||||||
SpecialJunk TagJunkGeneratorHandler
|
SpecialJunk TagJunkPacketGenerators
|
||||||
ControlledJunk TagJunkGeneratorHandler
|
ControlledJunk TagJunkPacketGenerators
|
||||||
|
|
||||||
nextItime time.Time
|
nextItime time.Time
|
||||||
ITimeout time.Duration // seconds
|
ITimeout time.Duration // seconds
|
||||||
|
@ -46,14 +46,14 @@ func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte {
|
||||||
if !handler.SpecialJunk.IsDefined() {
|
if !handler.SpecialJunk.IsDefined() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: create tests
|
// TODO: create tests
|
||||||
if !handler.isFirstDone {
|
if !handler.isFirstDone {
|
||||||
handler.isFirstDone = true
|
handler.isFirstDone = true
|
||||||
handler.nextItime = time.Now().Add(handler.ITimeout)
|
handler.nextItime = time.Now().Add(handler.ITimeout)
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if !handler.isTimeToSendSpecial() {
|
return handler.SpecialJunk.GeneratePackets()
|
||||||
|
} else if !handler.isTimeToSendSpecial() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -90,7 +90,10 @@ func newRandomPacketGenerator(param string) (Generator, error) {
|
||||||
return nil, fmt.Errorf("random packet crand read: %w", err)
|
return nil, fmt.Errorf("random packet crand read: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &RandomPacketGenerator{cha8Rand: v2.NewChaCha8([32]byte(buf)), size: size}, nil
|
return &RandomPacketGenerator{
|
||||||
|
cha8Rand: v2.NewChaCha8([32]byte(buf)),
|
||||||
|
size: size,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type TimestampGenerator struct {
|
type TimestampGenerator struct {
|
||||||
|
@ -137,7 +140,9 @@ func newWaitTimeoutGenerator(param string) (Generator, error) {
|
||||||
return nil, fmt.Errorf("timeout must be less than 5000ms")
|
return nil, fmt.Errorf("timeout must be less than 5000ms")
|
||||||
}
|
}
|
||||||
|
|
||||||
return &WaitTimeoutGenerator{waitTimeout: time.Duration(timeout) * time.Millisecond}, nil
|
return &WaitTimeoutGenerator{
|
||||||
|
waitTimeout: time.Duration(timeout) * time.Millisecond,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type PacketCounterGenerator struct {
|
type PacketCounterGenerator struct {
|
||||||
|
|
|
@ -6,13 +6,19 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type TagJunkPacketGenerator struct {
|
type TagJunkPacketGenerator struct {
|
||||||
name string
|
name string
|
||||||
|
tagValue string
|
||||||
|
|
||||||
packetSize int
|
packetSize int
|
||||||
generators []Generator
|
generators []Generator
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTagJunkPacketGenerator(name string, size int) TagJunkPacketGenerator {
|
func newTagJunkPacketGenerator(name, tagValue string, size int) TagJunkPacketGenerator {
|
||||||
return TagJunkPacketGenerator{name: name, generators: make([]Generator, 0, size)}
|
return TagJunkPacketGenerator{
|
||||||
|
name: name,
|
||||||
|
tagValue: tagValue,
|
||||||
|
generators: make([]Generator, 0, size),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tg *TagJunkPacketGenerator) append(generator Generator) {
|
func (tg *TagJunkPacketGenerator) append(generator Generator) {
|
||||||
|
@ -44,3 +50,10 @@ func (tg *TagJunkPacketGenerator) nameIndex() (int, error) {
|
||||||
}
|
}
|
||||||
return index, nil
|
return index, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (tg *TagJunkPacketGenerator) IpcGetFields() IpcFields {
|
||||||
|
return IpcFields{
|
||||||
|
Key: tg.name,
|
||||||
|
Value: tg.tagValue,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -78,8 +78,12 @@ func TestTagJunkGeneratorAppend(t *testing.T) {
|
||||||
expectedSize: 5,
|
expectedSize: 5,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Append to non-empty generator",
|
name: "Append to non-empty generator",
|
||||||
initialState: TagJunkPacketGenerator{name: "T2", packetSize: 10, generators: make([]Generator, 2)},
|
initialState: TagJunkPacketGenerator{
|
||||||
|
name: "T2",
|
||||||
|
packetSize: 10,
|
||||||
|
generators: make([]Generator, 2),
|
||||||
|
},
|
||||||
mockSize: 7,
|
mockSize: 7,
|
||||||
expectedLength: 3, // 2 existing + 1 new
|
expectedLength: 3, // 2 existing + 1 new
|
||||||
expectedSize: 17, // 10 + 7
|
expectedSize: 17, // 10 + 7
|
||||||
|
|
|
@ -8,7 +8,9 @@ type TagJunkPacketGenerators struct {
|
||||||
DefaultJunkCount int // Jc
|
DefaultJunkCount int // Jc
|
||||||
}
|
}
|
||||||
|
|
||||||
func (generators *TagJunkPacketGenerators) AppendGenerator(generator TagJunkPacketGenerator) {
|
func (generators *TagJunkPacketGenerators) AppendGenerator(
|
||||||
|
generator TagJunkPacketGenerator,
|
||||||
|
) {
|
||||||
generators.tagGenerators = append(generators.tagGenerators, generator)
|
generators.tagGenerators = append(generators.tagGenerators, generator)
|
||||||
generators.length++
|
generators.length++
|
||||||
}
|
}
|
||||||
|
@ -45,11 +47,20 @@ func (generators *TagJunkPacketGenerators) GeneratePackets() [][]byte {
|
||||||
var rv = make([][]byte, 0, generators.length+generators.DefaultJunkCount)
|
var rv = make([][]byte, 0, generators.length+generators.DefaultJunkCount)
|
||||||
|
|
||||||
for i, tagGenerator := range generators.tagGenerators {
|
for i, tagGenerator := range generators.tagGenerators {
|
||||||
PacketCounter.Inc()
|
|
||||||
rv = append(rv, make([]byte, tagGenerator.packetSize))
|
rv = append(rv, make([]byte, tagGenerator.packetSize))
|
||||||
copy(rv[i], tagGenerator.generatePacket())
|
copy(rv[i], tagGenerator.generatePacket())
|
||||||
|
PacketCounter.Inc()
|
||||||
}
|
}
|
||||||
PacketCounter.Add(uint64(generators.DefaultJunkCount))
|
PacketCounter.Add(uint64(generators.DefaultJunkCount))
|
||||||
|
|
||||||
return rv
|
return rv
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (tg *TagJunkPacketGenerators) IpcGetFields() []IpcFields {
|
||||||
|
rv := make([]IpcFields, 0, len(tg.tagGenerators))
|
||||||
|
for _, generator := range tg.tagGenerators {
|
||||||
|
rv = append(rv, generator.IpcGetFields())
|
||||||
|
}
|
||||||
|
|
||||||
|
return rv
|
||||||
|
}
|
||||||
|
|
|
@ -7,6 +7,8 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type IpcFields struct{ Key, Value string }
|
||||||
|
|
||||||
type EnumTag string
|
type EnumTag string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -53,7 +55,6 @@ func parseTag(input string) (Tag, error) {
|
||||||
return tag, nil
|
return tag, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: pointernes
|
|
||||||
func Parse(name, input string) (TagJunkPacketGenerator, error) {
|
func Parse(name, input string) (TagJunkPacketGenerator, error) {
|
||||||
inputSlice := strings.Split(input, "<")
|
inputSlice := strings.Split(input, "<")
|
||||||
if len(inputSlice) <= 1 {
|
if len(inputSlice) <= 1 {
|
||||||
|
@ -65,10 +66,13 @@ func Parse(name, input string) (TagJunkPacketGenerator, error) {
|
||||||
|
|
||||||
// skip byproduct of split
|
// skip byproduct of split
|
||||||
inputSlice = inputSlice[1:]
|
inputSlice = inputSlice[1:]
|
||||||
rv := newTagJunkPacketGenerator(name, len(inputSlice))
|
rv := newTagJunkPacketGenerator(name, input, len(inputSlice))
|
||||||
for _, inputParam := range inputSlice {
|
for _, inputParam := range inputSlice {
|
||||||
if len(inputParam) <= 1 {
|
if len(inputParam) <= 1 {
|
||||||
return TagJunkPacketGenerator{}, fmt.Errorf("empty tag in input: %s", inputSlice)
|
return TagJunkPacketGenerator{}, fmt.Errorf(
|
||||||
|
"empty tag in input: %s",
|
||||||
|
inputSlice,
|
||||||
|
)
|
||||||
} else if strings.Count(inputParam, ">") != 1 {
|
} else if strings.Count(inputParam, ">") != 1 {
|
||||||
return TagJunkPacketGenerator{}, fmt.Errorf("ill formated input: %s", input)
|
return TagJunkPacketGenerator{}, fmt.Errorf("ill formated input: %s", input)
|
||||||
}
|
}
|
||||||
|
@ -80,7 +84,10 @@ func Parse(name, input string) (TagJunkPacketGenerator, error) {
|
||||||
}
|
}
|
||||||
if present, ok := uniqueTagCheck[tag.Name]; ok {
|
if present, ok := uniqueTagCheck[tag.Name]; ok {
|
||||||
if present {
|
if present {
|
||||||
return TagJunkPacketGenerator{}, fmt.Errorf("tag %s needs to be unique", tag.Name)
|
return TagJunkPacketGenerator{}, fmt.Errorf(
|
||||||
|
"tag %s needs to be unique",
|
||||||
|
tag.Name,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
uniqueTagCheck[tag.Name] = true
|
uniqueTagCheck[tag.Name] = true
|
||||||
}
|
}
|
||||||
|
|
|
@ -773,7 +773,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if tempAwg.HandshakeHandler.IsSet {
|
if tempAwg.HandshakeHandler.IsSet {
|
||||||
if err := tempAwg.HandshakeHandler.Validate(); tempAwg.HandshakeHandler.IsSet && err != nil {
|
if err := tempAwg.HandshakeHandler.Validate(); err != nil {
|
||||||
errs = append(errs, ipcErrorf(
|
errs = append(errs, ipcErrorf(
|
||||||
ipc.IpcErrorInvalid, "handshake handler validate: %w", err))
|
ipc.IpcErrorInvalid, "handshake handler validate: %w", err))
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -7,19 +7,22 @@ package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
"os/signal"
|
||||||
"runtime"
|
"runtime"
|
||||||
"runtime/pprof"
|
"runtime/pprof"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"go.uber.org/atomic"
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
"github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
|
"github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
|
||||||
"github.com/amnezia-vpn/amneziawg-go/tun"
|
"github.com/amnezia-vpn/amneziawg-go/tun"
|
||||||
|
@ -242,7 +245,19 @@ func TestAWGDevicePing(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Needs to be stopped with Ctrl-C
|
||||||
func TestAWGHandshakeDevicePing(t *testing.T) {
|
func TestAWGHandshakeDevicePing(t *testing.T) {
|
||||||
|
t.Skip("This test is intended to be run manually, not as part of the test suite.")
|
||||||
|
|
||||||
|
signalContext, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
|
||||||
|
defer cancel()
|
||||||
|
isRunning := atomic.NewBool(true)
|
||||||
|
go func() {
|
||||||
|
<-signalContext.Done()
|
||||||
|
fmt.Println("Waiting to finish")
|
||||||
|
isRunning.Store(false)
|
||||||
|
}()
|
||||||
|
|
||||||
goroutineLeakCheck(t)
|
goroutineLeakCheck(t)
|
||||||
pair := genTestPair(t, true,
|
pair := genTestPair(t, true,
|
||||||
"i1", "<b 0xf6ab3267fa><c><b 0xf6ab><t><r 10><wt 10>",
|
"i1", "<b 0xf6ab3267fa><c><b 0xf6ab><t><r 10><wt 10>",
|
||||||
|
@ -262,13 +277,13 @@ func TestAWGHandshakeDevicePing(t *testing.T) {
|
||||||
// "h3", "123123",
|
// "h3", "123123",
|
||||||
)
|
)
|
||||||
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
||||||
for {
|
for isRunning.Load() {
|
||||||
pair.Send(t, Ping, nil)
|
pair.Send(t, Ping, nil)
|
||||||
time.Sleep(2 * time.Second)
|
time.Sleep(2 * time.Second)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run("ping 1.0.0.2", func(t *testing.T) {
|
t.Run("ping 1.0.0.2", func(t *testing.T) {
|
||||||
for {
|
for isRunning.Load() {
|
||||||
pair.Send(t, Pong, nil)
|
pair.Send(t, Pong, nil)
|
||||||
time.Sleep(2 * time.Second)
|
time.Sleep(2 * time.Second)
|
||||||
}
|
}
|
||||||
|
|
|
@ -126,10 +126,14 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
|
||||||
if device.awg.ASecCfg.TransportPacketMagicHeader != 0 {
|
if device.awg.ASecCfg.TransportPacketMagicHeader != 0 {
|
||||||
sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader)
|
sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader)
|
||||||
}
|
}
|
||||||
|
specialJunkIpcFields := device.awg.HandshakeHandler.SpecialJunk.IpcGetFields()
|
||||||
// for _, generator := range device.awg.HandshakeHandler.ControlledJunk.AppendGenerator {
|
for _, field := range specialJunkIpcFields {
|
||||||
// sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader)
|
sendf("%s=%s", field.Key, field.Value)
|
||||||
// }
|
}
|
||||||
|
controlledJunkIpcFields := device.awg.HandshakeHandler.ControlledJunk.IpcGetFields()
|
||||||
|
for _, field := range controlledJunkIpcFields {
|
||||||
|
sendf("%s=%s", field.Key, field.Value)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, peer := range device.peers.keyMap {
|
for _, peer := range device.peers.keyMap {
|
||||||
|
@ -283,7 +287,11 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
|
||||||
|
|
||||||
case "replace_peers":
|
case "replace_peers":
|
||||||
if value != "true" {
|
if value != "true" {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value)
|
return ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"failed to set replace_peers, invalid value: %v",
|
||||||
|
value,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
device.log.Verbosef("UAPI: Removing all peers")
|
device.log.Verbosef("UAPI: Removing all peers")
|
||||||
device.RemoveAllPeers()
|
device.RemoveAllPeers()
|
||||||
|
@ -470,7 +478,11 @@ func (device *Device) handlePeerLine(
|
||||||
case "update_only":
|
case "update_only":
|
||||||
// allow disabling of creation
|
// allow disabling of creation
|
||||||
if value != "true" {
|
if value != "true" {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value)
|
return ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"failed to set update only, invalid value: %v",
|
||||||
|
value,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
if peer.created && !peer.dummy {
|
if peer.created && !peer.dummy {
|
||||||
device.RemovePeer(peer.handshake.remoteStatic)
|
device.RemovePeer(peer.handshake.remoteStatic)
|
||||||
|
@ -516,7 +528,11 @@ func (device *Device) handlePeerLine(
|
||||||
|
|
||||||
secs, err := strconv.ParseUint(value, 10, 16)
|
secs, err := strconv.ParseUint(value, 10, 16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
|
return ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"failed to set persistent keepalive interval: %w",
|
||||||
|
err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
|
old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
|
||||||
|
@ -527,7 +543,11 @@ func (device *Device) handlePeerLine(
|
||||||
case "replace_allowed_ips":
|
case "replace_allowed_ips":
|
||||||
device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
|
device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
|
||||||
if value != "true" {
|
if value != "true" {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value)
|
return ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"failed to replace allowedips, invalid value: %v",
|
||||||
|
value,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
if peer.dummy {
|
if peer.dummy {
|
||||||
return nil
|
return nil
|
||||||
|
@ -595,7 +615,11 @@ func (device *Device) IpcHandle(socket net.Conn) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if nextByte != '\n' {
|
if nextByte != '\n' {
|
||||||
err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte)
|
err = ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"trailing character in UAPI get: %q",
|
||||||
|
nextByte,
|
||||||
|
)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
err = device.IpcGetOperation(buffered.Writer)
|
err = device.IpcGetOperation(buffered.Writer)
|
||||||
|
|
Loading…
Add table
Reference in a new issue