mirror of
https://github.com/amnezia-vpn/amneziawg-go.git
synced 2025-08-02 17:52:50 +02:00
feat: ready for tools implementation
This commit is contained in:
parent
e997fe1def
commit
a77df8158d
9 changed files with 110 additions and 31 deletions
|
@ -22,8 +22,8 @@ var WaitResponse = struct {
|
|||
|
||||
type SpecialHandshakeHandler struct {
|
||||
isFirstDone bool
|
||||
SpecialJunk TagJunkGeneratorHandler
|
||||
ControlledJunk TagJunkGeneratorHandler
|
||||
SpecialJunk TagJunkPacketGenerators
|
||||
ControlledJunk TagJunkPacketGenerators
|
||||
|
||||
nextItime time.Time
|
||||
ITimeout time.Duration // seconds
|
||||
|
@ -46,14 +46,14 @@ func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte {
|
|||
if !handler.SpecialJunk.IsDefined() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: create tests
|
||||
if !handler.isFirstDone {
|
||||
handler.isFirstDone = true
|
||||
handler.nextItime = time.Now().Add(handler.ITimeout)
|
||||
return nil
|
||||
}
|
||||
|
||||
if !handler.isTimeToSendSpecial() {
|
||||
return handler.SpecialJunk.GeneratePackets()
|
||||
} else if !handler.isTimeToSendSpecial() {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -90,7 +90,10 @@ func newRandomPacketGenerator(param string) (Generator, error) {
|
|||
return nil, fmt.Errorf("random packet crand read: %w", err)
|
||||
}
|
||||
|
||||
return &RandomPacketGenerator{cha8Rand: v2.NewChaCha8([32]byte(buf)), size: size}, nil
|
||||
return &RandomPacketGenerator{
|
||||
cha8Rand: v2.NewChaCha8([32]byte(buf)),
|
||||
size: size,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type TimestampGenerator struct {
|
||||
|
@ -137,7 +140,9 @@ func newWaitTimeoutGenerator(param string) (Generator, error) {
|
|||
return nil, fmt.Errorf("timeout must be less than 5000ms")
|
||||
}
|
||||
|
||||
return &WaitTimeoutGenerator{waitTimeout: time.Duration(timeout) * time.Millisecond}, nil
|
||||
return &WaitTimeoutGenerator{
|
||||
waitTimeout: time.Duration(timeout) * time.Millisecond,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type PacketCounterGenerator struct {
|
||||
|
|
|
@ -6,13 +6,19 @@ import (
|
|||
)
|
||||
|
||||
type TagJunkPacketGenerator struct {
|
||||
name string
|
||||
name string
|
||||
tagValue string
|
||||
|
||||
packetSize int
|
||||
generators []Generator
|
||||
}
|
||||
|
||||
func newTagJunkPacketGenerator(name string, size int) TagJunkPacketGenerator {
|
||||
return TagJunkPacketGenerator{name: name, generators: make([]Generator, 0, size)}
|
||||
func newTagJunkPacketGenerator(name, tagValue string, size int) TagJunkPacketGenerator {
|
||||
return TagJunkPacketGenerator{
|
||||
name: name,
|
||||
tagValue: tagValue,
|
||||
generators: make([]Generator, 0, size),
|
||||
}
|
||||
}
|
||||
|
||||
func (tg *TagJunkPacketGenerator) append(generator Generator) {
|
||||
|
@ -44,3 +50,10 @@ func (tg *TagJunkPacketGenerator) nameIndex() (int, error) {
|
|||
}
|
||||
return index, nil
|
||||
}
|
||||
|
||||
func (tg *TagJunkPacketGenerator) IpcGetFields() IpcFields {
|
||||
return IpcFields{
|
||||
Key: tg.name,
|
||||
Value: tg.tagValue,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -78,8 +78,12 @@ func TestTagJunkGeneratorAppend(t *testing.T) {
|
|||
expectedSize: 5,
|
||||
},
|
||||
{
|
||||
name: "Append to non-empty generator",
|
||||
initialState: TagJunkPacketGenerator{name: "T2", packetSize: 10, generators: make([]Generator, 2)},
|
||||
name: "Append to non-empty generator",
|
||||
initialState: TagJunkPacketGenerator{
|
||||
name: "T2",
|
||||
packetSize: 10,
|
||||
generators: make([]Generator, 2),
|
||||
},
|
||||
mockSize: 7,
|
||||
expectedLength: 3, // 2 existing + 1 new
|
||||
expectedSize: 17, // 10 + 7
|
||||
|
|
|
@ -8,7 +8,9 @@ type TagJunkPacketGenerators struct {
|
|||
DefaultJunkCount int // Jc
|
||||
}
|
||||
|
||||
func (generators *TagJunkPacketGenerators) AppendGenerator(generator TagJunkPacketGenerator) {
|
||||
func (generators *TagJunkPacketGenerators) AppendGenerator(
|
||||
generator TagJunkPacketGenerator,
|
||||
) {
|
||||
generators.tagGenerators = append(generators.tagGenerators, generator)
|
||||
generators.length++
|
||||
}
|
||||
|
@ -45,11 +47,20 @@ func (generators *TagJunkPacketGenerators) GeneratePackets() [][]byte {
|
|||
var rv = make([][]byte, 0, generators.length+generators.DefaultJunkCount)
|
||||
|
||||
for i, tagGenerator := range generators.tagGenerators {
|
||||
PacketCounter.Inc()
|
||||
rv = append(rv, make([]byte, tagGenerator.packetSize))
|
||||
copy(rv[i], tagGenerator.generatePacket())
|
||||
PacketCounter.Inc()
|
||||
}
|
||||
PacketCounter.Add(uint64(generators.DefaultJunkCount))
|
||||
|
||||
return rv
|
||||
}
|
||||
|
||||
func (tg *TagJunkPacketGenerators) IpcGetFields() []IpcFields {
|
||||
rv := make([]IpcFields, 0, len(tg.tagGenerators))
|
||||
for _, generator := range tg.tagGenerators {
|
||||
rv = append(rv, generator.IpcGetFields())
|
||||
}
|
||||
|
||||
return rv
|
||||
}
|
||||
|
|
|
@ -7,6 +7,8 @@ import (
|
|||
"strings"
|
||||
)
|
||||
|
||||
type IpcFields struct{ Key, Value string }
|
||||
|
||||
type EnumTag string
|
||||
|
||||
const (
|
||||
|
@ -53,7 +55,6 @@ func parseTag(input string) (Tag, error) {
|
|||
return tag, nil
|
||||
}
|
||||
|
||||
// TODO: pointernes
|
||||
func Parse(name, input string) (TagJunkPacketGenerator, error) {
|
||||
inputSlice := strings.Split(input, "<")
|
||||
if len(inputSlice) <= 1 {
|
||||
|
@ -65,10 +66,13 @@ func Parse(name, input string) (TagJunkPacketGenerator, error) {
|
|||
|
||||
// skip byproduct of split
|
||||
inputSlice = inputSlice[1:]
|
||||
rv := newTagJunkPacketGenerator(name, len(inputSlice))
|
||||
rv := newTagJunkPacketGenerator(name, input, len(inputSlice))
|
||||
for _, inputParam := range inputSlice {
|
||||
if len(inputParam) <= 1 {
|
||||
return TagJunkPacketGenerator{}, fmt.Errorf("empty tag in input: %s", inputSlice)
|
||||
return TagJunkPacketGenerator{}, fmt.Errorf(
|
||||
"empty tag in input: %s",
|
||||
inputSlice,
|
||||
)
|
||||
} else if strings.Count(inputParam, ">") != 1 {
|
||||
return TagJunkPacketGenerator{}, fmt.Errorf("ill formated input: %s", input)
|
||||
}
|
||||
|
@ -80,7 +84,10 @@ func Parse(name, input string) (TagJunkPacketGenerator, error) {
|
|||
}
|
||||
if present, ok := uniqueTagCheck[tag.Name]; ok {
|
||||
if present {
|
||||
return TagJunkPacketGenerator{}, fmt.Errorf("tag %s needs to be unique", tag.Name)
|
||||
return TagJunkPacketGenerator{}, fmt.Errorf(
|
||||
"tag %s needs to be unique",
|
||||
tag.Name,
|
||||
)
|
||||
}
|
||||
uniqueTagCheck[tag.Name] = true
|
||||
}
|
||||
|
|
|
@ -773,7 +773,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
|
|||
}
|
||||
|
||||
if tempAwg.HandshakeHandler.IsSet {
|
||||
if err := tempAwg.HandshakeHandler.Validate(); tempAwg.HandshakeHandler.IsSet && err != nil {
|
||||
if err := tempAwg.HandshakeHandler.Validate(); err != nil {
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid, "handshake handler validate: %w", err))
|
||||
} else {
|
||||
|
|
|
@ -7,19 +7,22 @@ package device
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"runtime/pprof"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
|
||||
"github.com/amnezia-vpn/amneziawg-go/tun"
|
||||
|
@ -242,7 +245,19 @@ func TestAWGDevicePing(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
// Needs to be stopped with Ctrl-C
|
||||
func TestAWGHandshakeDevicePing(t *testing.T) {
|
||||
t.Skip("This test is intended to be run manually, not as part of the test suite.")
|
||||
|
||||
signalContext, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
|
||||
defer cancel()
|
||||
isRunning := atomic.NewBool(true)
|
||||
go func() {
|
||||
<-signalContext.Done()
|
||||
fmt.Println("Waiting to finish")
|
||||
isRunning.Store(false)
|
||||
}()
|
||||
|
||||
goroutineLeakCheck(t)
|
||||
pair := genTestPair(t, true,
|
||||
"i1", "<b 0xf6ab3267fa><c><b 0xf6ab><t><r 10><wt 10>",
|
||||
|
@ -262,13 +277,13 @@ func TestAWGHandshakeDevicePing(t *testing.T) {
|
|||
// "h3", "123123",
|
||||
)
|
||||
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
||||
for {
|
||||
for isRunning.Load() {
|
||||
pair.Send(t, Ping, nil)
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
})
|
||||
t.Run("ping 1.0.0.2", func(t *testing.T) {
|
||||
for {
|
||||
for isRunning.Load() {
|
||||
pair.Send(t, Pong, nil)
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
|
|
|
@ -126,10 +126,14 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
|
|||
if device.awg.ASecCfg.TransportPacketMagicHeader != 0 {
|
||||
sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader)
|
||||
}
|
||||
|
||||
// for _, generator := range device.awg.HandshakeHandler.ControlledJunk.AppendGenerator {
|
||||
// sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader)
|
||||
// }
|
||||
specialJunkIpcFields := device.awg.HandshakeHandler.SpecialJunk.IpcGetFields()
|
||||
for _, field := range specialJunkIpcFields {
|
||||
sendf("%s=%s", field.Key, field.Value)
|
||||
}
|
||||
controlledJunkIpcFields := device.awg.HandshakeHandler.ControlledJunk.IpcGetFields()
|
||||
for _, field := range controlledJunkIpcFields {
|
||||
sendf("%s=%s", field.Key, field.Value)
|
||||
}
|
||||
}
|
||||
|
||||
for _, peer := range device.peers.keyMap {
|
||||
|
@ -283,7 +287,11 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
|
|||
|
||||
case "replace_peers":
|
||||
if value != "true" {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value)
|
||||
return ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"failed to set replace_peers, invalid value: %v",
|
||||
value,
|
||||
)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Removing all peers")
|
||||
device.RemoveAllPeers()
|
||||
|
@ -470,7 +478,11 @@ func (device *Device) handlePeerLine(
|
|||
case "update_only":
|
||||
// allow disabling of creation
|
||||
if value != "true" {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value)
|
||||
return ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"failed to set update only, invalid value: %v",
|
||||
value,
|
||||
)
|
||||
}
|
||||
if peer.created && !peer.dummy {
|
||||
device.RemovePeer(peer.handshake.remoteStatic)
|
||||
|
@ -516,7 +528,11 @@ func (device *Device) handlePeerLine(
|
|||
|
||||
secs, err := strconv.ParseUint(value, 10, 16)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
|
||||
return ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"failed to set persistent keepalive interval: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
|
||||
|
@ -527,7 +543,11 @@ func (device *Device) handlePeerLine(
|
|||
case "replace_allowed_ips":
|
||||
device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
|
||||
if value != "true" {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value)
|
||||
return ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"failed to replace allowedips, invalid value: %v",
|
||||
value,
|
||||
)
|
||||
}
|
||||
if peer.dummy {
|
||||
return nil
|
||||
|
@ -595,7 +615,11 @@ func (device *Device) IpcHandle(socket net.Conn) {
|
|||
return
|
||||
}
|
||||
if nextByte != '\n' {
|
||||
err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte)
|
||||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"trailing character in UAPI get: %q",
|
||||
nextByte,
|
||||
)
|
||||
break
|
||||
}
|
||||
err = device.IpcGetOperation(buffered.Writer)
|
||||
|
|
Loading…
Add table
Reference in a new issue