mirror of
https://github.com/amnezia-vpn/amneziawg-go.git
synced 2025-08-02 09:52:49 +02:00
feat: test
This commit is contained in:
parent
c66702372d
commit
f6c385f6a7
14 changed files with 240 additions and 97 deletions
|
@ -22,18 +22,18 @@ func NewJunkCreator(aSecCfg aSecCfgType) (junkCreator, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should be called with aSecMux RLocked
|
// Should be called with aSecMux RLocked
|
||||||
func (jc *junkCreator) CreateJunkPackets(junks [][]byte) error {
|
func (jc *junkCreator) CreateJunkPackets(junks *[][]byte) error {
|
||||||
if jc.aSecCfg.JunkPacketCount == 0 {
|
if jc.aSecCfg.JunkPacketCount == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range jc.aSecCfg.JunkPacketCount {
|
for range jc.aSecCfg.JunkPacketCount {
|
||||||
packetSize := jc.randomPacketSize()
|
packetSize := jc.randomPacketSize()
|
||||||
junk, err := jc.randomJunkWithSize(packetSize)
|
junk, err := jc.randomJunkWithSize(packetSize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create junk packet: %v", err)
|
return fmt.Errorf("create junk packet: %v", err)
|
||||||
}
|
}
|
||||||
junks[i] = junk
|
*junks = append(*junks, junk)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,7 +35,7 @@ func Test_junkCreator_createJunkPackets(t *testing.T) {
|
||||||
}
|
}
|
||||||
t.Run("valid", func(t *testing.T) {
|
t.Run("valid", func(t *testing.T) {
|
||||||
got := make([][]byte, jc.aSecCfg.JunkPacketCount)
|
got := make([][]byte, jc.aSecCfg.JunkPacketCount)
|
||||||
err := jc.CreateJunkPackets(got)
|
err := jc.CreateJunkPackets(&got)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf(
|
t.Errorf(
|
||||||
"junkCreator.createJunkPackets() = %v; failed",
|
"junkCreator.createJunkPackets() = %v; failed",
|
||||||
|
|
|
@ -3,18 +3,22 @@ package awg
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"go.uber.org/atomic"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TODO: atomic/ and better way to use this
|
||||||
|
var PacketCounter *atomic.Uint64 = atomic.NewUint64(0)
|
||||||
|
|
||||||
type SpecialHandshakeHandler struct {
|
type SpecialHandshakeHandler struct {
|
||||||
|
isFirstDone bool
|
||||||
SpecialJunk TagJunkGeneratorHandler
|
SpecialJunk TagJunkGeneratorHandler
|
||||||
ControlledJunk TagJunkGeneratorHandler
|
ControlledJunk TagJunkGeneratorHandler
|
||||||
|
|
||||||
nextItime time.Time
|
nextItime time.Time
|
||||||
ITimeout time.Duration // seconds
|
ITimeout time.Duration // seconds
|
||||||
|
|
||||||
// TODO: maybe atomic?
|
IsSet bool
|
||||||
PacketCounter uint64
|
|
||||||
IsSet bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (handler *SpecialHandshakeHandler) Validate() error {
|
func (handler *SpecialHandshakeHandler) Validate() error {
|
||||||
|
@ -29,13 +33,21 @@ func (handler *SpecialHandshakeHandler) Validate() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte {
|
func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte {
|
||||||
// TODO: distiungish between first and the rest of the packets
|
if !handler.SpecialJunk.IsDefined() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// TODO: create tests
|
||||||
|
if !handler.isFirstDone {
|
||||||
|
handler.isFirstDone = true
|
||||||
|
handler.nextItime = time.Now().Add(time.Duration(handler.ITimeout))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
if !handler.isTimeToSendSpecial() {
|
if !handler.isTimeToSendSpecial() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rv := handler.SpecialJunk.GeneratePackets()
|
rv := handler.SpecialJunk.GeneratePackets()
|
||||||
|
|
||||||
handler.nextItime = time.Now().Add(time.Duration(handler.ITimeout))
|
handler.nextItime = time.Now().Add(time.Duration(handler.ITimeout))
|
||||||
|
|
||||||
return rv
|
return rv
|
||||||
|
@ -45,6 +57,10 @@ func (handler *SpecialHandshakeHandler) isTimeToSendSpecial() bool {
|
||||||
return time.Now().After(handler.nextItime)
|
return time.Now().After(handler.nextItime)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (handler *SpecialHandshakeHandler) PrepareControlledJunk() [][]byte {
|
func (handler *SpecialHandshakeHandler) GenerateControlledJunk() [][]byte {
|
||||||
|
if !handler.ControlledJunk.IsDefined() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
return handler.ControlledJunk.GeneratePackets()
|
return handler.ControlledJunk.GeneratePackets()
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
v2 "math/rand/v2"
|
v2 "math/rand/v2"
|
||||||
|
// "go.uber.org/atomic"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Generator interface {
|
type Generator interface {
|
||||||
|
@ -33,9 +34,8 @@ func (bg *BytesGenerator) Size() int {
|
||||||
}
|
}
|
||||||
|
|
||||||
func newBytesGenerator(param string) (Generator, error) {
|
func newBytesGenerator(param string) (Generator, error) {
|
||||||
isNotHex := !strings.HasPrefix(param, "0x") ||
|
hasPrefix := strings.HasPrefix(param, "0x") || strings.HasPrefix(param, "0X")
|
||||||
!strings.HasPrefix(param, "0x") && !isHexString(param)
|
if !hasPrefix {
|
||||||
if isNotHex {
|
|
||||||
return nil, fmt.Errorf("not correct hex: %s", param)
|
return nil, fmt.Errorf("not correct hex: %s", param)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -47,17 +47,6 @@ func newBytesGenerator(param string) (Generator, error) {
|
||||||
return &BytesGenerator{value: hex, size: len(hex)}, nil
|
return &BytesGenerator{value: hex, size: len(hex)}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func isHexString(s string) bool {
|
|
||||||
for _, char := range s {
|
|
||||||
if !((char >= '0' && char <= '9') ||
|
|
||||||
(char >= 'a' && char <= 'f') ||
|
|
||||||
(char >= 'A' && char <= 'F')) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return len(s) > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func hexToBytes(hexStr string) ([]byte, error) {
|
func hexToBytes(hexStr string) ([]byte, error) {
|
||||||
hexStr = strings.TrimPrefix(hexStr, "0x")
|
hexStr = strings.TrimPrefix(hexStr, "0x")
|
||||||
hexStr = strings.TrimPrefix(hexStr, "0X")
|
hexStr = strings.TrimPrefix(hexStr, "0X")
|
||||||
|
@ -110,6 +99,7 @@ type TimestampGenerator struct {
|
||||||
func (tg *TimestampGenerator) Generate() []byte {
|
func (tg *TimestampGenerator) Generate() []byte {
|
||||||
buf := make([]byte, 8)
|
buf := make([]byte, 8)
|
||||||
binary.BigEndian.PutUint64(buf, uint64(time.Now().Unix()))
|
binary.BigEndian.PutUint64(buf, uint64(time.Now().Unix()))
|
||||||
|
fmt.Printf("timestamp: %v\n", buf)
|
||||||
return buf
|
return buf
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -130,6 +120,7 @@ type WaitTimeoutGenerator struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (wtg *WaitTimeoutGenerator) Generate() []byte {
|
func (wtg *WaitTimeoutGenerator) Generate() []byte {
|
||||||
|
fmt.Printf("sleep: %d\n", wtg.waitTimeout.Milliseconds())
|
||||||
time.Sleep(wtg.waitTimeout)
|
time.Sleep(wtg.waitTimeout)
|
||||||
return []byte{}
|
return []byte{}
|
||||||
}
|
}
|
||||||
|
@ -139,14 +130,39 @@ func (wtg *WaitTimeoutGenerator) Size() int {
|
||||||
}
|
}
|
||||||
|
|
||||||
func newWaitTimeoutGenerator(param string) (Generator, error) {
|
func newWaitTimeoutGenerator(param string) (Generator, error) {
|
||||||
size, err := strconv.Atoi(param)
|
t, err := strconv.Atoi(param)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("timeout parse int: %w", err)
|
return nil, fmt.Errorf("timeout parse int: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if size > 5000 {
|
if t > 5000 {
|
||||||
return nil, fmt.Errorf("timeout size must be less than 5000ms")
|
return nil, fmt.Errorf("timeout must be less than 5000ms")
|
||||||
}
|
}
|
||||||
|
|
||||||
return &WaitTimeoutGenerator{}, nil
|
return &WaitTimeoutGenerator{waitTimeout: time.Duration(t) * time.Millisecond}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type PacketCounterGenerator struct {
|
||||||
|
// counter *atomic.Uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PacketCounterGenerator) Generate() []byte {
|
||||||
|
buf := make([]byte, 8)
|
||||||
|
// TODO: better way to handle counter tag
|
||||||
|
binary.BigEndian.PutUint64(buf, PacketCounter.Load())
|
||||||
|
fmt.Printf("packet %d; counter: %v\n", PacketCounter.Load(), buf)
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PacketCounterGenerator) Size() int {
|
||||||
|
return 8
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPacketCounterGenerator(param string) (Generator, error) {
|
||||||
|
if len(param) != 0 {
|
||||||
|
return nil, fmt.Errorf("packet counter param needs to be empty: %s", param)
|
||||||
|
}
|
||||||
|
|
||||||
|
// return &PacketCounterGenerator{counter: atomic.NewUint64(0)}, nil
|
||||||
|
return &PacketCounterGenerator{}, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package awg
|
package awg
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -32,7 +33,14 @@ func Test_newBytesGenerator(t *testing.T) {
|
||||||
wantErr: fmt.Errorf("not correct hex"),
|
wantErr: fmt.Errorf("not correct hex"),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "not only hex value",
|
name: "not only hex value with X",
|
||||||
|
args: args{
|
||||||
|
param: "0X12345q",
|
||||||
|
},
|
||||||
|
wantErr: fmt.Errorf("not correct hex"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not only hex value with x",
|
||||||
args: args{
|
args: args{
|
||||||
param: "0x12345q",
|
param: "0x12345q",
|
||||||
},
|
},
|
||||||
|
@ -127,3 +135,55 @@ func Test_newRandomPacketGenerator(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPacketCounterGenerator(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
param string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Valid empty param",
|
||||||
|
param: "",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid non-empty param",
|
||||||
|
param: "anything",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
tc := tc // capture range variable
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
gen, err := newPacketCounterGenerator(tc.param)
|
||||||
|
if tc.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 8, gen.Size())
|
||||||
|
|
||||||
|
// Reset counter to known value for test
|
||||||
|
initialCount := uint64(42)
|
||||||
|
PacketCounter.Store(initialCount)
|
||||||
|
|
||||||
|
output := gen.Generate()
|
||||||
|
require.Equal(t, 8, len(output))
|
||||||
|
|
||||||
|
// Verify counter value in output
|
||||||
|
counterValue := binary.BigEndian.Uint64(output)
|
||||||
|
require.Equal(t, initialCount, counterValue)
|
||||||
|
|
||||||
|
// Increment counter and verify change
|
||||||
|
PacketCounter.Add(1)
|
||||||
|
output = gen.Generate()
|
||||||
|
counterValue = binary.BigEndian.Uint64(output)
|
||||||
|
require.Equal(t, initialCount+1, counterValue)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -12,7 +12,7 @@ type TagJunkGenerator struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTagJunkGenerator(name string, size int) TagJunkGenerator {
|
func newTagJunkGenerator(name string, size int) TagJunkGenerator {
|
||||||
return TagJunkGenerator{name: name, generators: make([]Generator, size)}
|
return TagJunkGenerator{name: name, generators: make([]Generator, 0, size)}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tg *TagJunkGenerator) append(generator Generator) {
|
func (tg *TagJunkGenerator) append(generator Generator) {
|
||||||
|
@ -40,7 +40,7 @@ func (tg *TagJunkGenerator) nameIndex() (int, error) {
|
||||||
|
|
||||||
index, err := strconv.Atoi(tg.name[1:2])
|
index, err := strconv.Atoi(tg.name[1:2])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("name should be 2 char long: %w", err)
|
return 0, fmt.Errorf("name 2 char should be an int %w", err)
|
||||||
}
|
}
|
||||||
return index, nil
|
return index, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,23 +3,26 @@ package awg
|
||||||
import "fmt"
|
import "fmt"
|
||||||
|
|
||||||
type TagJunkGeneratorHandler struct {
|
type TagJunkGeneratorHandler struct {
|
||||||
generators []TagJunkGenerator
|
tagGenerators []TagJunkGenerator
|
||||||
length int
|
length int
|
||||||
// Jc
|
DefaultJunkCount int // Jc
|
||||||
DefaultJunkCount int
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (handler *TagJunkGeneratorHandler) AppendGenerator(generators TagJunkGenerator) {
|
func (handler *TagJunkGeneratorHandler) AppendGenerator(generators TagJunkGenerator) {
|
||||||
handler.generators = append(handler.generators, generators)
|
handler.tagGenerators = append(handler.tagGenerators, generators)
|
||||||
handler.length++
|
handler.length++
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (handler *TagJunkGeneratorHandler) IsDefined() bool {
|
||||||
|
return len(handler.tagGenerators) > 0
|
||||||
|
}
|
||||||
|
|
||||||
// validate that packets were defined consecutively
|
// validate that packets were defined consecutively
|
||||||
func (handler *TagJunkGeneratorHandler) Validate() error {
|
func (handler *TagJunkGeneratorHandler) Validate() error {
|
||||||
seen := make([]bool, len(handler.generators))
|
seen := make([]bool, len(handler.tagGenerators))
|
||||||
for _, generator := range handler.generators {
|
for _, generator := range handler.tagGenerators {
|
||||||
index, err := generator.nameIndex()
|
index, err := generator.nameIndex()
|
||||||
if index > len(handler.generators) {
|
if index > len(handler.tagGenerators) {
|
||||||
return fmt.Errorf("junk packet index should be consecutive")
|
return fmt.Errorf("junk packet index should be consecutive")
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -39,10 +42,10 @@ func (handler *TagJunkGeneratorHandler) Validate() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (handler *TagJunkGeneratorHandler) GeneratePackets() [][]byte {
|
func (handler *TagJunkGeneratorHandler) GeneratePackets() [][]byte {
|
||||||
var rv = make([][]byte, handler.length+handler.DefaultJunkCount)
|
var rv = make([][]byte, 0, handler.length+handler.DefaultJunkCount)
|
||||||
for i, generator := range handler.generators {
|
for i, tagGenerator := range handler.tagGenerators {
|
||||||
rv[i] = make([]byte, generator.packetSize)
|
rv = append(rv, make([]byte, tagGenerator.packetSize))
|
||||||
copy(rv[i], generator.generatePacket())
|
copy(rv[i], tagGenerator.generatePacket())
|
||||||
}
|
}
|
||||||
|
|
||||||
return rv
|
return rv
|
||||||
|
|
|
@ -8,8 +8,6 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestTagJunkGeneratorHandlerAppendGenerator(t *testing.T) {
|
func TestTagJunkGeneratorHandlerAppendGenerator(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
generator TagJunkGenerator
|
generator TagJunkGenerator
|
||||||
|
@ -28,20 +26,18 @@ func TestTagJunkGeneratorHandlerAppendGenerator(t *testing.T) {
|
||||||
|
|
||||||
// Initial length should be 0
|
// Initial length should be 0
|
||||||
require.Equal(t, 0, handler.length)
|
require.Equal(t, 0, handler.length)
|
||||||
require.Empty(t, handler.generators)
|
require.Empty(t, handler.tagGenerators)
|
||||||
|
|
||||||
// After append, length should be 1 and generator should be added
|
// After append, length should be 1 and generator should be added
|
||||||
handler.AppendGenerator(tt.generator)
|
handler.AppendGenerator(tt.generator)
|
||||||
require.Equal(t, 1, handler.length)
|
require.Equal(t, 1, handler.length)
|
||||||
require.Len(t, handler.generators, 1)
|
require.Len(t, handler.tagGenerators, 1)
|
||||||
require.Equal(t, tt.generator, handler.generators[0])
|
require.Equal(t, tt.generator, handler.tagGenerators[0])
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTagJunkGeneratorHandlerValidate(t *testing.T) {
|
func TestTagJunkGeneratorHandlerValidate(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
generators []TagJunkGenerator
|
generators []TagJunkGenerator
|
||||||
|
@ -49,12 +45,13 @@ func TestTagJunkGeneratorHandlerValidate(t *testing.T) {
|
||||||
errMsg string
|
errMsg string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "valid consecutive indices",
|
name: "bad start",
|
||||||
generators: []TagJunkGenerator{
|
generators: []TagJunkGenerator{
|
||||||
newTagJunkGenerator("t1", 10),
|
newTagJunkGenerator("t3", 10),
|
||||||
newTagJunkGenerator("t2", 10),
|
newTagJunkGenerator("t4", 10),
|
||||||
},
|
},
|
||||||
wantErr: false,
|
wantErr: true,
|
||||||
|
errMsg: "junk packet index should be consecutive",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "non-consecutive indices",
|
name: "non-consecutive indices",
|
||||||
|
@ -65,6 +62,16 @@ func TestTagJunkGeneratorHandlerValidate(t *testing.T) {
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
errMsg: "junk packet index should be consecutive",
|
errMsg: "junk packet index should be consecutive",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "consecutive indices",
|
||||||
|
generators: []TagJunkGenerator{
|
||||||
|
newTagJunkGenerator("t1", 10),
|
||||||
|
newTagJunkGenerator("t2", 10),
|
||||||
|
newTagJunkGenerator("t3", 10),
|
||||||
|
newTagJunkGenerator("t4", 10),
|
||||||
|
newTagJunkGenerator("t5", 10),
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "nameIndex error",
|
name: "nameIndex error",
|
||||||
generators: []TagJunkGenerator{
|
generators: []TagJunkGenerator{
|
||||||
|
@ -88,16 +95,14 @@ func TestTagJunkGeneratorHandlerValidate(t *testing.T) {
|
||||||
if tt.wantErr {
|
if tt.wantErr {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), tt.errMsg)
|
require.Contains(t, err.Error(), tt.errMsg)
|
||||||
} else {
|
return
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTagJunkGeneratorHandlerGenerate(t *testing.T) {
|
func TestTagJunkGeneratorHandlerGenerate(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
mockByte1 := []byte{0x01, 0x02}
|
mockByte1 := []byte{0x01, 0x02}
|
||||||
mockByte2 := []byte{0x03, 0x04, 0x05}
|
mockByte2 := []byte{0x03, 0x04, 0x05}
|
||||||
mockGen1 := internal.NewMockByteGenerator(mockByte1)
|
mockGen1 := internal.NewMockByteGenerator(mockByte1)
|
||||||
|
|
|
@ -20,7 +20,7 @@ const (
|
||||||
|
|
||||||
var generatorCreator = map[EnumTag]newGenerator{
|
var generatorCreator = map[EnumTag]newGenerator{
|
||||||
BytesEnumTag: newBytesGenerator,
|
BytesEnumTag: newBytesGenerator,
|
||||||
CounterEnumTag: func(s string) (Generator, error) { return &BytesGenerator{}, nil },
|
CounterEnumTag: newPacketCounterGenerator,
|
||||||
TimestampEnumTag: newTimestampGenerator,
|
TimestampEnumTag: newTimestampGenerator,
|
||||||
RandomBytesEnumTag: newRandomPacketGenerator,
|
RandomBytesEnumTag: newRandomPacketGenerator,
|
||||||
WaitTimeoutEnumTag: newWaitTimeoutGenerator,
|
WaitTimeoutEnumTag: newWaitTimeoutGenerator,
|
||||||
|
@ -89,6 +89,15 @@ func Parse(name, input string) (TagJunkGenerator, error) {
|
||||||
return TagJunkGenerator{}, fmt.Errorf("gen: %w", err)
|
return TagJunkGenerator{}, fmt.Errorf("gen: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: handle counter tag
|
||||||
|
// if tag.Name == CounterEnumTag {
|
||||||
|
// packetCounter, ok := generator.(*PacketCounterGenerator)
|
||||||
|
// if !ok {
|
||||||
|
// log.Fatalf("packet counter generator expected, got %T", generator)
|
||||||
|
// }
|
||||||
|
// PacketCounter = packetCounter.counter
|
||||||
|
// }
|
||||||
|
|
||||||
rv.append(generator)
|
rv.append(generator)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -91,7 +91,7 @@ func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
|
func genAWGConfigs(tb testing.TB, cfg ...string) (cfgs, endpointCfgs [2]string) {
|
||||||
var key1, key2 NoisePrivateKey
|
var key1, key2 NoisePrivateKey
|
||||||
_, err := rand.Read(key1[:])
|
_, err := rand.Read(key1[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -103,46 +103,35 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
|
||||||
}
|
}
|
||||||
pub1, pub2 := key1.publicKey(), key2.publicKey()
|
pub1, pub2 := key1.publicKey(), key2.publicKey()
|
||||||
|
|
||||||
cfgs[0] = uapiCfg(
|
args0 := append([]string(nil), cfg...)
|
||||||
|
args0 = append(args0, []string{
|
||||||
"private_key", hex.EncodeToString(key1[:]),
|
"private_key", hex.EncodeToString(key1[:]),
|
||||||
"listen_port", "0",
|
"listen_port", "0",
|
||||||
"replace_peers", "true",
|
"replace_peers", "true",
|
||||||
"jc", "5",
|
|
||||||
"jmin", "500",
|
|
||||||
"jmax", "1000",
|
|
||||||
"s1", "30",
|
|
||||||
"s2", "40",
|
|
||||||
"h1", "123456",
|
|
||||||
"h2", "67543",
|
|
||||||
"h4", "32345",
|
|
||||||
"h3", "123123",
|
|
||||||
"public_key", hex.EncodeToString(pub2[:]),
|
"public_key", hex.EncodeToString(pub2[:]),
|
||||||
"protocol_version", "1",
|
"protocol_version", "1",
|
||||||
"replace_allowed_ips", "true",
|
"replace_allowed_ips", "true",
|
||||||
"allowed_ip", "1.0.0.2/32",
|
"allowed_ip", "1.0.0.2/32",
|
||||||
)
|
}...)
|
||||||
|
cfgs[0] = uapiCfg(args0...)
|
||||||
|
|
||||||
endpointCfgs[0] = uapiCfg(
|
endpointCfgs[0] = uapiCfg(
|
||||||
"public_key", hex.EncodeToString(pub2[:]),
|
"public_key", hex.EncodeToString(pub2[:]),
|
||||||
"endpoint", "127.0.0.1:%d",
|
"endpoint", "127.0.0.1:%d",
|
||||||
)
|
)
|
||||||
cfgs[1] = uapiCfg(
|
|
||||||
|
args1 := append([]string(nil), cfg...)
|
||||||
|
args1 = append(args1, []string{
|
||||||
"private_key", hex.EncodeToString(key2[:]),
|
"private_key", hex.EncodeToString(key2[:]),
|
||||||
"listen_port", "0",
|
"listen_port", "0",
|
||||||
"replace_peers", "true",
|
"replace_peers", "true",
|
||||||
"jc", "5",
|
|
||||||
"jmin", "500",
|
|
||||||
"jmax", "1000",
|
|
||||||
"s1", "30",
|
|
||||||
"s2", "40",
|
|
||||||
"h1", "123456",
|
|
||||||
"h2", "67543",
|
|
||||||
"h4", "32345",
|
|
||||||
"h3", "123123",
|
|
||||||
"public_key", hex.EncodeToString(pub1[:]),
|
"public_key", hex.EncodeToString(pub1[:]),
|
||||||
"protocol_version", "1",
|
"protocol_version", "1",
|
||||||
"replace_allowed_ips", "true",
|
"replace_allowed_ips", "true",
|
||||||
"allowed_ip", "1.0.0.1/32",
|
"allowed_ip", "1.0.0.1/32",
|
||||||
)
|
}...)
|
||||||
|
|
||||||
|
cfgs[1] = uapiCfg(args1...)
|
||||||
endpointCfgs[1] = uapiCfg(
|
endpointCfgs[1] = uapiCfg(
|
||||||
"public_key", hex.EncodeToString(pub1[:]),
|
"public_key", hex.EncodeToString(pub1[:]),
|
||||||
"endpoint", "127.0.0.1:%d",
|
"endpoint", "127.0.0.1:%d",
|
||||||
|
@ -214,11 +203,12 @@ func (pair *testPair) Send(
|
||||||
// genTestPair creates a testPair.
|
// genTestPair creates a testPair.
|
||||||
func genTestPair(
|
func genTestPair(
|
||||||
tb testing.TB,
|
tb testing.TB,
|
||||||
realSocket, withASecurity bool,
|
realSocket bool,
|
||||||
|
extraCfg ...string,
|
||||||
) (pair testPair) {
|
) (pair testPair) {
|
||||||
var cfg, endpointCfg [2]string
|
var cfg, endpointCfg [2]string
|
||||||
if withASecurity {
|
if len(extraCfg) > 0 {
|
||||||
cfg, endpointCfg = genASecurityConfigs(tb)
|
cfg, endpointCfg = genAWGConfigs(tb, extraCfg...)
|
||||||
} else {
|
} else {
|
||||||
cfg, endpointCfg = genConfigs(tb)
|
cfg, endpointCfg = genConfigs(tb)
|
||||||
}
|
}
|
||||||
|
@ -265,7 +255,7 @@ func genTestPair(
|
||||||
|
|
||||||
func TestTwoDevicePing(t *testing.T) {
|
func TestTwoDevicePing(t *testing.T) {
|
||||||
goroutineLeakCheck(t)
|
goroutineLeakCheck(t)
|
||||||
pair := genTestPair(t, true, false)
|
pair := genTestPair(t, true)
|
||||||
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
||||||
pair.Send(t, Ping, nil)
|
pair.Send(t, Ping, nil)
|
||||||
})
|
})
|
||||||
|
@ -275,9 +265,45 @@ func TestTwoDevicePing(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run test with -race=false to avoid the race for setting the default msgTypes 2 times
|
// Run test with -race=false to avoid the race for setting the default msgTypes 2 times
|
||||||
func TestASecurityTwoDevicePing(t *testing.T) {
|
func TestAWGDevicePing(t *testing.T) {
|
||||||
goroutineLeakCheck(t)
|
goroutineLeakCheck(t)
|
||||||
pair := genTestPair(t, true, true)
|
pair := genTestPair(t, true,
|
||||||
|
"jc", "5",
|
||||||
|
"jmin", "500",
|
||||||
|
"jmax", "1000",
|
||||||
|
"s1", "30",
|
||||||
|
"s2", "40",
|
||||||
|
"h1", "123456",
|
||||||
|
"h2", "67543",
|
||||||
|
"h4", "32345",
|
||||||
|
"h3", "123123",
|
||||||
|
)
|
||||||
|
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
||||||
|
pair.Send(t, Ping, nil)
|
||||||
|
})
|
||||||
|
t.Run("ping 1.0.0.2", func(t *testing.T) {
|
||||||
|
pair.Send(t, Pong, nil)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAWGHandshakeDevicePing(t *testing.T) {
|
||||||
|
goroutineLeakCheck(t)
|
||||||
|
pair := genTestPair(t, true,
|
||||||
|
// "i1", "<b 0xf6ab3267fa><c><b 0xf6ab><t><r 10><wt 10>",
|
||||||
|
// "i2", "<b 0xf6ab3267fa><r 100>",
|
||||||
|
"j1", "<b 0xffffffff><c><b 0xf6ab><t><r 10>",
|
||||||
|
"j2", "<c><b 0xf6ab><t><wt 1000>",
|
||||||
|
"j3", "<t><b 0xf6ab><c><r 10>",
|
||||||
|
// "jc", "1",
|
||||||
|
// "jmin", "500",
|
||||||
|
// "jmax", "1000",
|
||||||
|
// "s1", "30",
|
||||||
|
// "s2", "40",
|
||||||
|
// "h1", "123456",
|
||||||
|
// "h2", "67543",
|
||||||
|
// "h4", "32345",
|
||||||
|
// "h3", "123123",
|
||||||
|
)
|
||||||
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
||||||
pair.Send(t, Ping, nil)
|
pair.Send(t, Ping, nil)
|
||||||
})
|
})
|
||||||
|
@ -292,7 +318,7 @@ func TestUpDown(t *testing.T) {
|
||||||
const otrials = 10
|
const otrials = 10
|
||||||
|
|
||||||
for n := 0; n < otrials; n++ {
|
for n := 0; n < otrials; n++ {
|
||||||
pair := genTestPair(t, false, false)
|
pair := genTestPair(t, false)
|
||||||
for i := range pair {
|
for i := range pair {
|
||||||
for k := range pair[i].dev.peers.keyMap {
|
for k := range pair[i].dev.peers.keyMap {
|
||||||
pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:])))
|
pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:])))
|
||||||
|
@ -326,7 +352,7 @@ func TestUpDown(t *testing.T) {
|
||||||
// TestConcurrencySafety does other things concurrently with tunnel use.
|
// TestConcurrencySafety does other things concurrently with tunnel use.
|
||||||
// It is intended to be used with the race detector to catch data races.
|
// It is intended to be used with the race detector to catch data races.
|
||||||
func TestConcurrencySafety(t *testing.T) {
|
func TestConcurrencySafety(t *testing.T) {
|
||||||
pair := genTestPair(t, true, false)
|
pair := genTestPair(t, true)
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
|
|
||||||
const warmupIters = 10
|
const warmupIters = 10
|
||||||
|
@ -407,7 +433,7 @@ func TestConcurrencySafety(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkLatency(b *testing.B) {
|
func BenchmarkLatency(b *testing.B) {
|
||||||
pair := genTestPair(b, true, false)
|
pair := genTestPair(b, true)
|
||||||
|
|
||||||
// Establish a connection.
|
// Establish a connection.
|
||||||
pair.Send(b, Ping, nil)
|
pair.Send(b, Ping, nil)
|
||||||
|
@ -421,7 +447,7 @@ func BenchmarkLatency(b *testing.B) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkThroughput(b *testing.B) {
|
func BenchmarkThroughput(b *testing.B) {
|
||||||
pair := genTestPair(b, true, false)
|
pair := genTestPair(b, true)
|
||||||
|
|
||||||
// Establish a connection.
|
// Establish a connection.
|
||||||
pair.Send(b, Ping, nil)
|
pair.Send(b, Ping, nil)
|
||||||
|
@ -465,7 +491,7 @@ func BenchmarkThroughput(b *testing.B) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkUAPIGet(b *testing.B) {
|
func BenchmarkUAPIGet(b *testing.B) {
|
||||||
pair := genTestPair(b, true, false)
|
pair := genTestPair(b, true)
|
||||||
pair.Send(b, Ping, nil)
|
pair.Send(b, Ping, nil)
|
||||||
pair.Send(b, Pong, nil)
|
pair.Send(b, Pong, nil)
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
|
"github.com/amnezia-vpn/amneziawg-go/device/awg"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Peer struct {
|
type Peer struct {
|
||||||
|
@ -137,7 +138,10 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
var totalLen uint64
|
var totalLen uint64
|
||||||
for _, b := range buffers {
|
for _, b := range buffers {
|
||||||
peer.device.awg.HandshakeHandler.PacketCounter++
|
// TODO
|
||||||
|
awg.PacketCounter.Inc()
|
||||||
|
peer.device.log.Verbosef("%v - Sending %d bytes to %s; pc: %d", peer, len(b), endpoint)
|
||||||
|
|
||||||
totalLen += uint64(len(b))
|
totalLen += uint64(len(b))
|
||||||
}
|
}
|
||||||
peer.txBytes.Add(totalLen)
|
peer.txBytes.Add(totalLen)
|
||||||
|
|
|
@ -134,14 +134,15 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
||||||
// set junks depending on packet type
|
// set junks depending on packet type
|
||||||
junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk()
|
junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk()
|
||||||
if junks == nil {
|
if junks == nil {
|
||||||
junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk()
|
peer.device.log.Verbosef("%v - No special junks defined, using controlled", peer)
|
||||||
|
junks = peer.device.awg.HandshakeHandler.GenerateControlledJunk()
|
||||||
}
|
}
|
||||||
peer.device.awg.ASecMux.RUnlock()
|
peer.device.awg.ASecMux.RUnlock()
|
||||||
} else {
|
} else {
|
||||||
junks = make([][]byte, peer.device.awg.ASecCfg.JunkPacketCount)
|
junks = make([][]byte, 0, peer.device.awg.ASecCfg.JunkPacketCount)
|
||||||
}
|
}
|
||||||
peer.device.awg.ASecMux.RLock()
|
peer.device.awg.ASecMux.RLock()
|
||||||
err := peer.device.awg.JunkCreator.CreateJunkPackets(junks)
|
err := peer.device.awg.JunkCreator.CreateJunkPackets(&junks)
|
||||||
peer.device.awg.ASecMux.RUnlock()
|
peer.device.awg.ASecMux.RUnlock()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
1
go.mod
1
go.mod
|
@ -5,6 +5,7 @@ go 1.24
|
||||||
require (
|
require (
|
||||||
github.com/stretchr/testify v1.10.0
|
github.com/stretchr/testify v1.10.0
|
||||||
github.com/tevino/abool v1.2.0
|
github.com/tevino/abool v1.2.0
|
||||||
|
go.uber.org/atomic v1.11.0
|
||||||
golang.org/x/crypto v0.36.0
|
golang.org/x/crypto v0.36.0
|
||||||
golang.org/x/net v0.37.0
|
golang.org/x/net v0.37.0
|
||||||
golang.org/x/sys v0.31.0
|
golang.org/x/sys v0.31.0
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -10,6 +10,8 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf
|
||||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
github.com/tevino/abool v1.2.0 h1:heAkClL8H6w+mK5md9dzsuohKeXHUpY7Vw0ZCKW+huA=
|
github.com/tevino/abool v1.2.0 h1:heAkClL8H6w+mK5md9dzsuohKeXHUpY7Vw0ZCKW+huA=
|
||||||
github.com/tevino/abool v1.2.0/go.mod h1:qc66Pna1RiIsPa7O4Egxxs9OqkuxDX55zznh9K07Tzg=
|
github.com/tevino/abool v1.2.0/go.mod h1:qc66Pna1RiIsPa7O4Egxxs9OqkuxDX55zznh9K07Tzg=
|
||||||
|
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
|
||||||
|
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
|
||||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||||
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
|
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
|
||||||
|
|
Loading…
Add table
Reference in a new issue