mirror of
https://github.com/amnezia-vpn/amneziawg-go.git
synced 2025-07-30 08:32:50 +02:00
fix: packet counter; test special handshake
This commit is contained in:
parent
f6c385f6a7
commit
e8dc69d407
9 changed files with 84 additions and 33 deletions
|
@ -4,12 +4,22 @@ import (
|
|||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
// TODO: atomic/ and better way to use this
|
||||
// TODO: atomic?/ and better way to use this
|
||||
var PacketCounter *atomic.Uint64 = atomic.NewUint64(0)
|
||||
|
||||
// TODO
|
||||
var WaitResponse = struct {
|
||||
Channel chan struct{}
|
||||
ShouldWait *abool.AtomicBool
|
||||
}{
|
||||
make(chan struct{}, 1),
|
||||
abool.New(),
|
||||
}
|
||||
|
||||
type SpecialHandshakeHandler struct {
|
||||
isFirstDone bool
|
||||
SpecialJunk TagJunkGeneratorHandler
|
||||
|
@ -39,7 +49,7 @@ func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte {
|
|||
// TODO: create tests
|
||||
if !handler.isFirstDone {
|
||||
handler.isFirstDone = true
|
||||
handler.nextItime = time.Now().Add(time.Duration(handler.ITimeout))
|
||||
handler.nextItime = time.Now().Add(handler.ITimeout)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -48,7 +58,7 @@ func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte {
|
|||
}
|
||||
|
||||
rv := handler.SpecialJunk.GeneratePackets()
|
||||
handler.nextItime = time.Now().Add(time.Duration(handler.ITimeout))
|
||||
handler.nextItime = time.Now().Add(handler.ITimeout)
|
||||
|
||||
return rv
|
||||
}
|
||||
|
|
|
@ -99,7 +99,6 @@ type TimestampGenerator struct {
|
|||
func (tg *TimestampGenerator) Generate() []byte {
|
||||
buf := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(buf, uint64(time.Now().Unix()))
|
||||
fmt.Printf("timestamp: %v\n", buf)
|
||||
return buf
|
||||
}
|
||||
|
||||
|
@ -120,7 +119,6 @@ type WaitTimeoutGenerator struct {
|
|||
}
|
||||
|
||||
func (wtg *WaitTimeoutGenerator) Generate() []byte {
|
||||
fmt.Printf("sleep: %d\n", wtg.waitTimeout.Milliseconds())
|
||||
time.Sleep(wtg.waitTimeout)
|
||||
return []byte{}
|
||||
}
|
||||
|
@ -130,27 +128,25 @@ func (wtg *WaitTimeoutGenerator) Size() int {
|
|||
}
|
||||
|
||||
func newWaitTimeoutGenerator(param string) (Generator, error) {
|
||||
t, err := strconv.Atoi(param)
|
||||
timeout, err := strconv.Atoi(param)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("timeout parse int: %w", err)
|
||||
}
|
||||
|
||||
if t > 5000 {
|
||||
if timeout > 5000 {
|
||||
return nil, fmt.Errorf("timeout must be less than 5000ms")
|
||||
}
|
||||
|
||||
return &WaitTimeoutGenerator{waitTimeout: time.Duration(t) * time.Millisecond}, nil
|
||||
return &WaitTimeoutGenerator{waitTimeout: time.Duration(timeout) * time.Millisecond}, nil
|
||||
}
|
||||
|
||||
type PacketCounterGenerator struct {
|
||||
// counter *atomic.Uint64
|
||||
}
|
||||
|
||||
func (c *PacketCounterGenerator) Generate() []byte {
|
||||
buf := make([]byte, 8)
|
||||
// TODO: better way to handle counter tag
|
||||
binary.BigEndian.PutUint64(buf, PacketCounter.Load())
|
||||
fmt.Printf("packet %d; counter: %v\n", PacketCounter.Load(), buf)
|
||||
return buf
|
||||
}
|
||||
|
||||
|
@ -163,6 +159,27 @@ func newPacketCounterGenerator(param string) (Generator, error) {
|
|||
return nil, fmt.Errorf("packet counter param needs to be empty: %s", param)
|
||||
}
|
||||
|
||||
// return &PacketCounterGenerator{counter: atomic.NewUint64(0)}, nil
|
||||
return &PacketCounterGenerator{}, nil
|
||||
}
|
||||
|
||||
type WaitResponseGenerator struct {
|
||||
}
|
||||
|
||||
func (c *WaitResponseGenerator) Generate() []byte {
|
||||
WaitResponse.ShouldWait.Set()
|
||||
<-WaitResponse.Channel
|
||||
WaitResponse.ShouldWait.UnSet()
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
func (c *WaitResponseGenerator) Size() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func newWaitResponseGenerator(param string) (Generator, error) {
|
||||
if len(param) != 0 {
|
||||
return nil, fmt.Errorf("wait response param needs to be empty: %s", param)
|
||||
}
|
||||
|
||||
return &WaitResponseGenerator{}, nil
|
||||
}
|
||||
|
|
|
@ -43,10 +43,13 @@ func (handler *TagJunkGeneratorHandler) Validate() error {
|
|||
|
||||
func (handler *TagJunkGeneratorHandler) GeneratePackets() [][]byte {
|
||||
var rv = make([][]byte, 0, handler.length+handler.DefaultJunkCount)
|
||||
|
||||
for i, tagGenerator := range handler.tagGenerators {
|
||||
PacketCounter.Inc()
|
||||
rv = append(rv, make([]byte, tagGenerator.packetSize))
|
||||
copy(rv[i], tagGenerator.generatePacket())
|
||||
}
|
||||
PacketCounter.Add(uint64(handler.DefaultJunkCount))
|
||||
|
||||
return rv
|
||||
}
|
||||
|
|
|
@ -19,12 +19,12 @@ const (
|
|||
)
|
||||
|
||||
var generatorCreator = map[EnumTag]newGenerator{
|
||||
BytesEnumTag: newBytesGenerator,
|
||||
CounterEnumTag: newPacketCounterGenerator,
|
||||
TimestampEnumTag: newTimestampGenerator,
|
||||
RandomBytesEnumTag: newRandomPacketGenerator,
|
||||
WaitTimeoutEnumTag: newWaitTimeoutGenerator,
|
||||
WaitResponseEnumTag: func(s string) (Generator, error) { return &BytesGenerator{}, nil },
|
||||
BytesEnumTag: newBytesGenerator,
|
||||
CounterEnumTag: newPacketCounterGenerator,
|
||||
TimestampEnumTag: newTimestampGenerator,
|
||||
RandomBytesEnumTag: newRandomPacketGenerator,
|
||||
WaitTimeoutEnumTag: newWaitTimeoutGenerator,
|
||||
// WaitResponseEnumTag: newWaitResponseGenerator,
|
||||
}
|
||||
|
||||
// helper map to determine enumTags are unique
|
||||
|
|
|
@ -176,7 +176,7 @@ func (pair *testPair) Send(
|
|||
}
|
||||
msg := tuntest.Ping(p0.ip, p1.ip)
|
||||
p1.tun.Outbound <- msg
|
||||
timer := time.NewTimer(5 * time.Second)
|
||||
timer := time.NewTimer(6 * time.Second)
|
||||
defer timer.Stop()
|
||||
var err error
|
||||
select {
|
||||
|
@ -289,11 +289,12 @@ func TestAWGDevicePing(t *testing.T) {
|
|||
func TestAWGHandshakeDevicePing(t *testing.T) {
|
||||
goroutineLeakCheck(t)
|
||||
pair := genTestPair(t, true,
|
||||
// "i1", "<b 0xf6ab3267fa><c><b 0xf6ab><t><r 10><wt 10>",
|
||||
// "i2", "<b 0xf6ab3267fa><r 100>",
|
||||
"i1", "<b 0xf6ab3267fa><c><b 0xf6ab><t><r 10><wt 10>",
|
||||
"i2", "<b 0xf6ab3267fa><r 100>",
|
||||
"j1", "<b 0xffffffff><c><b 0xf6ab><t><r 10>",
|
||||
"j2", "<c><b 0xf6ab><t><wt 1000>",
|
||||
"j3", "<t><b 0xf6ab><c><r 10>",
|
||||
"itime", "60",
|
||||
// "jc", "1",
|
||||
// "jmin", "500",
|
||||
// "jmax", "1000",
|
||||
|
@ -305,10 +306,16 @@ func TestAWGHandshakeDevicePing(t *testing.T) {
|
|||
// "h3", "123123",
|
||||
)
|
||||
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
||||
pair.Send(t, Ping, nil)
|
||||
for {
|
||||
pair.Send(t, Ping, nil)
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
})
|
||||
t.Run("ping 1.0.0.2", func(t *testing.T) {
|
||||
pair.Send(t, Pong, nil)
|
||||
for {
|
||||
pair.Send(t, Pong, nil)
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -138,10 +138,6 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error {
|
|||
if err == nil {
|
||||
var totalLen uint64
|
||||
for _, b := range buffers {
|
||||
// TODO
|
||||
awg.PacketCounter.Inc()
|
||||
peer.device.log.Verbosef("%v - Sending %d bytes to %s; pc: %d", peer, len(b), endpoint)
|
||||
|
||||
totalLen += uint64(len(b))
|
||||
}
|
||||
peer.txBytes.Add(totalLen)
|
||||
|
@ -149,6 +145,16 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (peer *Peer) SendBuffersCountPacket(buffers [][]byte) error {
|
||||
err := peer.SendBuffers(buffers)
|
||||
if err == nil {
|
||||
awg.PacketCounter.Add(uint64(len(buffers)))
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (peer *Peer) String() string {
|
||||
// The awful goo that follows is identical to:
|
||||
//
|
||||
|
|
|
@ -137,10 +137,14 @@ func (device *Device) RoutineReceiveIncoming(
|
|||
}
|
||||
|
||||
// check size of packet
|
||||
|
||||
packet := bufsArrs[i][:size]
|
||||
var msgType uint32
|
||||
if device.isAdvancedSecurityOn() {
|
||||
// TODO:
|
||||
// if awg.WaitResponse.ShouldWait.IsSet() {
|
||||
// awg.WaitResponse.Channel <- struct{}{}
|
||||
// }
|
||||
|
||||
if assumedMsgType, ok := packetSizeToMsgType[size]; ok {
|
||||
junkSize := msgTypeToJunkSize[assumedMsgType]
|
||||
// transport size can align with other header types;
|
||||
|
|
|
@ -134,8 +134,12 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
|||
// set junks depending on packet type
|
||||
junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk()
|
||||
if junks == nil {
|
||||
peer.device.log.Verbosef("%v - No special junks defined, using controlled", peer)
|
||||
junks = peer.device.awg.HandshakeHandler.GenerateControlledJunk()
|
||||
if junks != nil {
|
||||
peer.device.log.Verbosef("%v - Controlled junks sent", peer)
|
||||
}
|
||||
} else {
|
||||
peer.device.log.Verbosef("%v - Special junks sent", peer)
|
||||
}
|
||||
peer.device.awg.ASecMux.RUnlock()
|
||||
} else {
|
||||
|
@ -186,7 +190,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
|||
|
||||
sendBuffer = append(sendBuffer, junkedHeader)
|
||||
|
||||
err = peer.SendBuffers(sendBuffer)
|
||||
err = peer.SendBuffersCountPacket(sendBuffer)
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
|
||||
}
|
||||
|
@ -242,7 +246,7 @@ func (peer *Peer) SendHandshakeResponse() error {
|
|||
peer.timersAnyAuthenticatedPacketSent()
|
||||
|
||||
// TODO: allocation could be avoided
|
||||
err = peer.SendBuffers([][]byte{junkedHeader})
|
||||
err = peer.SendBuffersCountPacket([][]byte{junkedHeader})
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
|
||||
}
|
||||
|
@ -596,7 +600,7 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
|
|||
peer.timersAnyAuthenticatedPacketTraversal()
|
||||
peer.timersAnyAuthenticatedPacketSent()
|
||||
|
||||
err := peer.SendBuffers(bufs)
|
||||
err := peer.SendBuffersCountPacket(bufs)
|
||||
if dataSent {
|
||||
peer.timersDataSent()
|
||||
}
|
||||
|
|
|
@ -394,9 +394,9 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
|
|||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse itime %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating itime %s", itime)
|
||||
device.log.Verbosef("UAPI: Updating itime %d", itime)
|
||||
|
||||
tempAwg.HandshakeHandler.ITimeout = time.Duration(itime)
|
||||
tempAwg.HandshakeHandler.ITimeout = time.Duration(itime) * time.Second
|
||||
tempAwg.HandshakeHandler.IsSet = true
|
||||
default:
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
|
||||
|
|
Loading…
Add table
Reference in a new issue