mirror of
https://github.com/amnezia-vpn/amneziawg-go.git
synced 2025-08-01 17:32:51 +02:00
fix: packet counter; test special handshake
This commit is contained in:
parent
f6c385f6a7
commit
e8dc69d407
9 changed files with 84 additions and 33 deletions
|
@ -4,12 +4,22 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/tevino/abool"
|
||||||
"go.uber.org/atomic"
|
"go.uber.org/atomic"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: atomic/ and better way to use this
|
// TODO: atomic?/ and better way to use this
|
||||||
var PacketCounter *atomic.Uint64 = atomic.NewUint64(0)
|
var PacketCounter *atomic.Uint64 = atomic.NewUint64(0)
|
||||||
|
|
||||||
|
// TODO
|
||||||
|
var WaitResponse = struct {
|
||||||
|
Channel chan struct{}
|
||||||
|
ShouldWait *abool.AtomicBool
|
||||||
|
}{
|
||||||
|
make(chan struct{}, 1),
|
||||||
|
abool.New(),
|
||||||
|
}
|
||||||
|
|
||||||
type SpecialHandshakeHandler struct {
|
type SpecialHandshakeHandler struct {
|
||||||
isFirstDone bool
|
isFirstDone bool
|
||||||
SpecialJunk TagJunkGeneratorHandler
|
SpecialJunk TagJunkGeneratorHandler
|
||||||
|
@ -39,7 +49,7 @@ func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte {
|
||||||
// TODO: create tests
|
// TODO: create tests
|
||||||
if !handler.isFirstDone {
|
if !handler.isFirstDone {
|
||||||
handler.isFirstDone = true
|
handler.isFirstDone = true
|
||||||
handler.nextItime = time.Now().Add(time.Duration(handler.ITimeout))
|
handler.nextItime = time.Now().Add(handler.ITimeout)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -48,7 +58,7 @@ func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte {
|
||||||
}
|
}
|
||||||
|
|
||||||
rv := handler.SpecialJunk.GeneratePackets()
|
rv := handler.SpecialJunk.GeneratePackets()
|
||||||
handler.nextItime = time.Now().Add(time.Duration(handler.ITimeout))
|
handler.nextItime = time.Now().Add(handler.ITimeout)
|
||||||
|
|
||||||
return rv
|
return rv
|
||||||
}
|
}
|
||||||
|
|
|
@ -99,7 +99,6 @@ type TimestampGenerator struct {
|
||||||
func (tg *TimestampGenerator) Generate() []byte {
|
func (tg *TimestampGenerator) Generate() []byte {
|
||||||
buf := make([]byte, 8)
|
buf := make([]byte, 8)
|
||||||
binary.BigEndian.PutUint64(buf, uint64(time.Now().Unix()))
|
binary.BigEndian.PutUint64(buf, uint64(time.Now().Unix()))
|
||||||
fmt.Printf("timestamp: %v\n", buf)
|
|
||||||
return buf
|
return buf
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -120,7 +119,6 @@ type WaitTimeoutGenerator struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (wtg *WaitTimeoutGenerator) Generate() []byte {
|
func (wtg *WaitTimeoutGenerator) Generate() []byte {
|
||||||
fmt.Printf("sleep: %d\n", wtg.waitTimeout.Milliseconds())
|
|
||||||
time.Sleep(wtg.waitTimeout)
|
time.Sleep(wtg.waitTimeout)
|
||||||
return []byte{}
|
return []byte{}
|
||||||
}
|
}
|
||||||
|
@ -130,27 +128,25 @@ func (wtg *WaitTimeoutGenerator) Size() int {
|
||||||
}
|
}
|
||||||
|
|
||||||
func newWaitTimeoutGenerator(param string) (Generator, error) {
|
func newWaitTimeoutGenerator(param string) (Generator, error) {
|
||||||
t, err := strconv.Atoi(param)
|
timeout, err := strconv.Atoi(param)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("timeout parse int: %w", err)
|
return nil, fmt.Errorf("timeout parse int: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if t > 5000 {
|
if timeout > 5000 {
|
||||||
return nil, fmt.Errorf("timeout must be less than 5000ms")
|
return nil, fmt.Errorf("timeout must be less than 5000ms")
|
||||||
}
|
}
|
||||||
|
|
||||||
return &WaitTimeoutGenerator{waitTimeout: time.Duration(t) * time.Millisecond}, nil
|
return &WaitTimeoutGenerator{waitTimeout: time.Duration(timeout) * time.Millisecond}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type PacketCounterGenerator struct {
|
type PacketCounterGenerator struct {
|
||||||
// counter *atomic.Uint64
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *PacketCounterGenerator) Generate() []byte {
|
func (c *PacketCounterGenerator) Generate() []byte {
|
||||||
buf := make([]byte, 8)
|
buf := make([]byte, 8)
|
||||||
// TODO: better way to handle counter tag
|
// TODO: better way to handle counter tag
|
||||||
binary.BigEndian.PutUint64(buf, PacketCounter.Load())
|
binary.BigEndian.PutUint64(buf, PacketCounter.Load())
|
||||||
fmt.Printf("packet %d; counter: %v\n", PacketCounter.Load(), buf)
|
|
||||||
return buf
|
return buf
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -163,6 +159,27 @@ func newPacketCounterGenerator(param string) (Generator, error) {
|
||||||
return nil, fmt.Errorf("packet counter param needs to be empty: %s", param)
|
return nil, fmt.Errorf("packet counter param needs to be empty: %s", param)
|
||||||
}
|
}
|
||||||
|
|
||||||
// return &PacketCounterGenerator{counter: atomic.NewUint64(0)}, nil
|
|
||||||
return &PacketCounterGenerator{}, nil
|
return &PacketCounterGenerator{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type WaitResponseGenerator struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WaitResponseGenerator) Generate() []byte {
|
||||||
|
WaitResponse.ShouldWait.Set()
|
||||||
|
<-WaitResponse.Channel
|
||||||
|
WaitResponse.ShouldWait.UnSet()
|
||||||
|
return []byte{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WaitResponseGenerator) Size() int {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func newWaitResponseGenerator(param string) (Generator, error) {
|
||||||
|
if len(param) != 0 {
|
||||||
|
return nil, fmt.Errorf("wait response param needs to be empty: %s", param)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &WaitResponseGenerator{}, nil
|
||||||
|
}
|
||||||
|
|
|
@ -43,10 +43,13 @@ func (handler *TagJunkGeneratorHandler) Validate() error {
|
||||||
|
|
||||||
func (handler *TagJunkGeneratorHandler) GeneratePackets() [][]byte {
|
func (handler *TagJunkGeneratorHandler) GeneratePackets() [][]byte {
|
||||||
var rv = make([][]byte, 0, handler.length+handler.DefaultJunkCount)
|
var rv = make([][]byte, 0, handler.length+handler.DefaultJunkCount)
|
||||||
|
|
||||||
for i, tagGenerator := range handler.tagGenerators {
|
for i, tagGenerator := range handler.tagGenerators {
|
||||||
|
PacketCounter.Inc()
|
||||||
rv = append(rv, make([]byte, tagGenerator.packetSize))
|
rv = append(rv, make([]byte, tagGenerator.packetSize))
|
||||||
copy(rv[i], tagGenerator.generatePacket())
|
copy(rv[i], tagGenerator.generatePacket())
|
||||||
}
|
}
|
||||||
|
PacketCounter.Add(uint64(handler.DefaultJunkCount))
|
||||||
|
|
||||||
return rv
|
return rv
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,12 +19,12 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
var generatorCreator = map[EnumTag]newGenerator{
|
var generatorCreator = map[EnumTag]newGenerator{
|
||||||
BytesEnumTag: newBytesGenerator,
|
BytesEnumTag: newBytesGenerator,
|
||||||
CounterEnumTag: newPacketCounterGenerator,
|
CounterEnumTag: newPacketCounterGenerator,
|
||||||
TimestampEnumTag: newTimestampGenerator,
|
TimestampEnumTag: newTimestampGenerator,
|
||||||
RandomBytesEnumTag: newRandomPacketGenerator,
|
RandomBytesEnumTag: newRandomPacketGenerator,
|
||||||
WaitTimeoutEnumTag: newWaitTimeoutGenerator,
|
WaitTimeoutEnumTag: newWaitTimeoutGenerator,
|
||||||
WaitResponseEnumTag: func(s string) (Generator, error) { return &BytesGenerator{}, nil },
|
// WaitResponseEnumTag: newWaitResponseGenerator,
|
||||||
}
|
}
|
||||||
|
|
||||||
// helper map to determine enumTags are unique
|
// helper map to determine enumTags are unique
|
||||||
|
|
|
@ -176,7 +176,7 @@ func (pair *testPair) Send(
|
||||||
}
|
}
|
||||||
msg := tuntest.Ping(p0.ip, p1.ip)
|
msg := tuntest.Ping(p0.ip, p1.ip)
|
||||||
p1.tun.Outbound <- msg
|
p1.tun.Outbound <- msg
|
||||||
timer := time.NewTimer(5 * time.Second)
|
timer := time.NewTimer(6 * time.Second)
|
||||||
defer timer.Stop()
|
defer timer.Stop()
|
||||||
var err error
|
var err error
|
||||||
select {
|
select {
|
||||||
|
@ -289,11 +289,12 @@ func TestAWGDevicePing(t *testing.T) {
|
||||||
func TestAWGHandshakeDevicePing(t *testing.T) {
|
func TestAWGHandshakeDevicePing(t *testing.T) {
|
||||||
goroutineLeakCheck(t)
|
goroutineLeakCheck(t)
|
||||||
pair := genTestPair(t, true,
|
pair := genTestPair(t, true,
|
||||||
// "i1", "<b 0xf6ab3267fa><c><b 0xf6ab><t><r 10><wt 10>",
|
"i1", "<b 0xf6ab3267fa><c><b 0xf6ab><t><r 10><wt 10>",
|
||||||
// "i2", "<b 0xf6ab3267fa><r 100>",
|
"i2", "<b 0xf6ab3267fa><r 100>",
|
||||||
"j1", "<b 0xffffffff><c><b 0xf6ab><t><r 10>",
|
"j1", "<b 0xffffffff><c><b 0xf6ab><t><r 10>",
|
||||||
"j2", "<c><b 0xf6ab><t><wt 1000>",
|
"j2", "<c><b 0xf6ab><t><wt 1000>",
|
||||||
"j3", "<t><b 0xf6ab><c><r 10>",
|
"j3", "<t><b 0xf6ab><c><r 10>",
|
||||||
|
"itime", "60",
|
||||||
// "jc", "1",
|
// "jc", "1",
|
||||||
// "jmin", "500",
|
// "jmin", "500",
|
||||||
// "jmax", "1000",
|
// "jmax", "1000",
|
||||||
|
@ -305,10 +306,16 @@ func TestAWGHandshakeDevicePing(t *testing.T) {
|
||||||
// "h3", "123123",
|
// "h3", "123123",
|
||||||
)
|
)
|
||||||
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
||||||
pair.Send(t, Ping, nil)
|
for {
|
||||||
|
pair.Send(t, Ping, nil)
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
t.Run("ping 1.0.0.2", func(t *testing.T) {
|
t.Run("ping 1.0.0.2", func(t *testing.T) {
|
||||||
pair.Send(t, Pong, nil)
|
for {
|
||||||
|
pair.Send(t, Pong, nil)
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -138,10 +138,6 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
var totalLen uint64
|
var totalLen uint64
|
||||||
for _, b := range buffers {
|
for _, b := range buffers {
|
||||||
// TODO
|
|
||||||
awg.PacketCounter.Inc()
|
|
||||||
peer.device.log.Verbosef("%v - Sending %d bytes to %s; pc: %d", peer, len(b), endpoint)
|
|
||||||
|
|
||||||
totalLen += uint64(len(b))
|
totalLen += uint64(len(b))
|
||||||
}
|
}
|
||||||
peer.txBytes.Add(totalLen)
|
peer.txBytes.Add(totalLen)
|
||||||
|
@ -149,6 +145,16 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (peer *Peer) SendBuffersCountPacket(buffers [][]byte) error {
|
||||||
|
err := peer.SendBuffers(buffers)
|
||||||
|
if err == nil {
|
||||||
|
awg.PacketCounter.Add(uint64(len(buffers)))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (peer *Peer) String() string {
|
func (peer *Peer) String() string {
|
||||||
// The awful goo that follows is identical to:
|
// The awful goo that follows is identical to:
|
||||||
//
|
//
|
||||||
|
|
|
@ -137,10 +137,14 @@ func (device *Device) RoutineReceiveIncoming(
|
||||||
}
|
}
|
||||||
|
|
||||||
// check size of packet
|
// check size of packet
|
||||||
|
|
||||||
packet := bufsArrs[i][:size]
|
packet := bufsArrs[i][:size]
|
||||||
var msgType uint32
|
var msgType uint32
|
||||||
if device.isAdvancedSecurityOn() {
|
if device.isAdvancedSecurityOn() {
|
||||||
|
// TODO:
|
||||||
|
// if awg.WaitResponse.ShouldWait.IsSet() {
|
||||||
|
// awg.WaitResponse.Channel <- struct{}{}
|
||||||
|
// }
|
||||||
|
|
||||||
if assumedMsgType, ok := packetSizeToMsgType[size]; ok {
|
if assumedMsgType, ok := packetSizeToMsgType[size]; ok {
|
||||||
junkSize := msgTypeToJunkSize[assumedMsgType]
|
junkSize := msgTypeToJunkSize[assumedMsgType]
|
||||||
// transport size can align with other header types;
|
// transport size can align with other header types;
|
||||||
|
|
|
@ -134,8 +134,12 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
||||||
// set junks depending on packet type
|
// set junks depending on packet type
|
||||||
junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk()
|
junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk()
|
||||||
if junks == nil {
|
if junks == nil {
|
||||||
peer.device.log.Verbosef("%v - No special junks defined, using controlled", peer)
|
|
||||||
junks = peer.device.awg.HandshakeHandler.GenerateControlledJunk()
|
junks = peer.device.awg.HandshakeHandler.GenerateControlledJunk()
|
||||||
|
if junks != nil {
|
||||||
|
peer.device.log.Verbosef("%v - Controlled junks sent", peer)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
peer.device.log.Verbosef("%v - Special junks sent", peer)
|
||||||
}
|
}
|
||||||
peer.device.awg.ASecMux.RUnlock()
|
peer.device.awg.ASecMux.RUnlock()
|
||||||
} else {
|
} else {
|
||||||
|
@ -186,7 +190,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
||||||
|
|
||||||
sendBuffer = append(sendBuffer, junkedHeader)
|
sendBuffer = append(sendBuffer, junkedHeader)
|
||||||
|
|
||||||
err = peer.SendBuffers(sendBuffer)
|
err = peer.SendBuffersCountPacket(sendBuffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
|
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
|
||||||
}
|
}
|
||||||
|
@ -242,7 +246,7 @@ func (peer *Peer) SendHandshakeResponse() error {
|
||||||
peer.timersAnyAuthenticatedPacketSent()
|
peer.timersAnyAuthenticatedPacketSent()
|
||||||
|
|
||||||
// TODO: allocation could be avoided
|
// TODO: allocation could be avoided
|
||||||
err = peer.SendBuffers([][]byte{junkedHeader})
|
err = peer.SendBuffersCountPacket([][]byte{junkedHeader})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
|
peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
|
||||||
}
|
}
|
||||||
|
@ -596,7 +600,7 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
|
||||||
peer.timersAnyAuthenticatedPacketTraversal()
|
peer.timersAnyAuthenticatedPacketTraversal()
|
||||||
peer.timersAnyAuthenticatedPacketSent()
|
peer.timersAnyAuthenticatedPacketSent()
|
||||||
|
|
||||||
err := peer.SendBuffers(bufs)
|
err := peer.SendBuffersCountPacket(bufs)
|
||||||
if dataSent {
|
if dataSent {
|
||||||
peer.timersDataSent()
|
peer.timersDataSent()
|
||||||
}
|
}
|
||||||
|
|
|
@ -394,9 +394,9 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse itime %w", err)
|
return ipcErrorf(ipc.IpcErrorInvalid, "parse itime %w", err)
|
||||||
}
|
}
|
||||||
device.log.Verbosef("UAPI: Updating itime %s", itime)
|
device.log.Verbosef("UAPI: Updating itime %d", itime)
|
||||||
|
|
||||||
tempAwg.HandshakeHandler.ITimeout = time.Duration(itime)
|
tempAwg.HandshakeHandler.ITimeout = time.Duration(itime) * time.Second
|
||||||
tempAwg.HandshakeHandler.IsSet = true
|
tempAwg.HandshakeHandler.IsSet = true
|
||||||
default:
|
default:
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
|
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
|
||||||
|
|
Loading…
Add table
Reference in a new issue