fix: packet counter; test special handshake

This commit is contained in:
Mark Puha 2025-06-12 06:02:45 +02:00
parent f6c385f6a7
commit e8dc69d407
9 changed files with 84 additions and 33 deletions

View file

@ -4,12 +4,22 @@ import (
"errors"
"time"
"github.com/tevino/abool"
"go.uber.org/atomic"
)
// TODO: atomic/ and better way to use this
// TODO: atomic?/ and better way to use this
var PacketCounter *atomic.Uint64 = atomic.NewUint64(0)
// TODO
var WaitResponse = struct {
Channel chan struct{}
ShouldWait *abool.AtomicBool
}{
make(chan struct{}, 1),
abool.New(),
}
type SpecialHandshakeHandler struct {
isFirstDone bool
SpecialJunk TagJunkGeneratorHandler
@ -39,7 +49,7 @@ func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte {
// TODO: create tests
if !handler.isFirstDone {
handler.isFirstDone = true
handler.nextItime = time.Now().Add(time.Duration(handler.ITimeout))
handler.nextItime = time.Now().Add(handler.ITimeout)
return nil
}
@ -48,7 +58,7 @@ func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte {
}
rv := handler.SpecialJunk.GeneratePackets()
handler.nextItime = time.Now().Add(time.Duration(handler.ITimeout))
handler.nextItime = time.Now().Add(handler.ITimeout)
return rv
}

View file

@ -99,7 +99,6 @@ type TimestampGenerator struct {
func (tg *TimestampGenerator) Generate() []byte {
buf := make([]byte, 8)
binary.BigEndian.PutUint64(buf, uint64(time.Now().Unix()))
fmt.Printf("timestamp: %v\n", buf)
return buf
}
@ -120,7 +119,6 @@ type WaitTimeoutGenerator struct {
}
func (wtg *WaitTimeoutGenerator) Generate() []byte {
fmt.Printf("sleep: %d\n", wtg.waitTimeout.Milliseconds())
time.Sleep(wtg.waitTimeout)
return []byte{}
}
@ -130,27 +128,25 @@ func (wtg *WaitTimeoutGenerator) Size() int {
}
func newWaitTimeoutGenerator(param string) (Generator, error) {
t, err := strconv.Atoi(param)
timeout, err := strconv.Atoi(param)
if err != nil {
return nil, fmt.Errorf("timeout parse int: %w", err)
}
if t > 5000 {
if timeout > 5000 {
return nil, fmt.Errorf("timeout must be less than 5000ms")
}
return &WaitTimeoutGenerator{waitTimeout: time.Duration(t) * time.Millisecond}, nil
return &WaitTimeoutGenerator{waitTimeout: time.Duration(timeout) * time.Millisecond}, nil
}
type PacketCounterGenerator struct {
// counter *atomic.Uint64
}
func (c *PacketCounterGenerator) Generate() []byte {
buf := make([]byte, 8)
// TODO: better way to handle counter tag
binary.BigEndian.PutUint64(buf, PacketCounter.Load())
fmt.Printf("packet %d; counter: %v\n", PacketCounter.Load(), buf)
return buf
}
@ -163,6 +159,27 @@ func newPacketCounterGenerator(param string) (Generator, error) {
return nil, fmt.Errorf("packet counter param needs to be empty: %s", param)
}
// return &PacketCounterGenerator{counter: atomic.NewUint64(0)}, nil
return &PacketCounterGenerator{}, nil
}
type WaitResponseGenerator struct {
}
func (c *WaitResponseGenerator) Generate() []byte {
WaitResponse.ShouldWait.Set()
<-WaitResponse.Channel
WaitResponse.ShouldWait.UnSet()
return []byte{}
}
func (c *WaitResponseGenerator) Size() int {
return 0
}
func newWaitResponseGenerator(param string) (Generator, error) {
if len(param) != 0 {
return nil, fmt.Errorf("wait response param needs to be empty: %s", param)
}
return &WaitResponseGenerator{}, nil
}

View file

@ -43,10 +43,13 @@ func (handler *TagJunkGeneratorHandler) Validate() error {
func (handler *TagJunkGeneratorHandler) GeneratePackets() [][]byte {
var rv = make([][]byte, 0, handler.length+handler.DefaultJunkCount)
for i, tagGenerator := range handler.tagGenerators {
PacketCounter.Inc()
rv = append(rv, make([]byte, tagGenerator.packetSize))
copy(rv[i], tagGenerator.generatePacket())
}
PacketCounter.Add(uint64(handler.DefaultJunkCount))
return rv
}

View file

@ -19,12 +19,12 @@ const (
)
var generatorCreator = map[EnumTag]newGenerator{
BytesEnumTag: newBytesGenerator,
CounterEnumTag: newPacketCounterGenerator,
TimestampEnumTag: newTimestampGenerator,
RandomBytesEnumTag: newRandomPacketGenerator,
WaitTimeoutEnumTag: newWaitTimeoutGenerator,
WaitResponseEnumTag: func(s string) (Generator, error) { return &BytesGenerator{}, nil },
BytesEnumTag: newBytesGenerator,
CounterEnumTag: newPacketCounterGenerator,
TimestampEnumTag: newTimestampGenerator,
RandomBytesEnumTag: newRandomPacketGenerator,
WaitTimeoutEnumTag: newWaitTimeoutGenerator,
// WaitResponseEnumTag: newWaitResponseGenerator,
}
// helper map to determine enumTags are unique

View file

@ -176,7 +176,7 @@ func (pair *testPair) Send(
}
msg := tuntest.Ping(p0.ip, p1.ip)
p1.tun.Outbound <- msg
timer := time.NewTimer(5 * time.Second)
timer := time.NewTimer(6 * time.Second)
defer timer.Stop()
var err error
select {
@ -289,11 +289,12 @@ func TestAWGDevicePing(t *testing.T) {
func TestAWGHandshakeDevicePing(t *testing.T) {
goroutineLeakCheck(t)
pair := genTestPair(t, true,
// "i1", "<b 0xf6ab3267fa><c><b 0xf6ab><t><r 10><wt 10>",
// "i2", "<b 0xf6ab3267fa><r 100>",
"i1", "<b 0xf6ab3267fa><c><b 0xf6ab><t><r 10><wt 10>",
"i2", "<b 0xf6ab3267fa><r 100>",
"j1", "<b 0xffffffff><c><b 0xf6ab><t><r 10>",
"j2", "<c><b 0xf6ab><t><wt 1000>",
"j3", "<t><b 0xf6ab><c><r 10>",
"itime", "60",
// "jc", "1",
// "jmin", "500",
// "jmax", "1000",
@ -305,10 +306,16 @@ func TestAWGHandshakeDevicePing(t *testing.T) {
// "h3", "123123",
)
t.Run("ping 1.0.0.1", func(t *testing.T) {
pair.Send(t, Ping, nil)
for {
pair.Send(t, Ping, nil)
time.Sleep(2 * time.Second)
}
})
t.Run("ping 1.0.0.2", func(t *testing.T) {
pair.Send(t, Pong, nil)
for {
pair.Send(t, Pong, nil)
time.Sleep(2 * time.Second)
}
})
}

View file

@ -138,10 +138,6 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error {
if err == nil {
var totalLen uint64
for _, b := range buffers {
// TODO
awg.PacketCounter.Inc()
peer.device.log.Verbosef("%v - Sending %d bytes to %s; pc: %d", peer, len(b), endpoint)
totalLen += uint64(len(b))
}
peer.txBytes.Add(totalLen)
@ -149,6 +145,16 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error {
return err
}
func (peer *Peer) SendBuffersCountPacket(buffers [][]byte) error {
err := peer.SendBuffers(buffers)
if err == nil {
awg.PacketCounter.Add(uint64(len(buffers)))
return nil
}
return err
}
func (peer *Peer) String() string {
// The awful goo that follows is identical to:
//

View file

@ -137,10 +137,14 @@ func (device *Device) RoutineReceiveIncoming(
}
// check size of packet
packet := bufsArrs[i][:size]
var msgType uint32
if device.isAdvancedSecurityOn() {
// TODO:
// if awg.WaitResponse.ShouldWait.IsSet() {
// awg.WaitResponse.Channel <- struct{}{}
// }
if assumedMsgType, ok := packetSizeToMsgType[size]; ok {
junkSize := msgTypeToJunkSize[assumedMsgType]
// transport size can align with other header types;

View file

@ -134,8 +134,12 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
// set junks depending on packet type
junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk()
if junks == nil {
peer.device.log.Verbosef("%v - No special junks defined, using controlled", peer)
junks = peer.device.awg.HandshakeHandler.GenerateControlledJunk()
if junks != nil {
peer.device.log.Verbosef("%v - Controlled junks sent", peer)
}
} else {
peer.device.log.Verbosef("%v - Special junks sent", peer)
}
peer.device.awg.ASecMux.RUnlock()
} else {
@ -186,7 +190,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
sendBuffer = append(sendBuffer, junkedHeader)
err = peer.SendBuffers(sendBuffer)
err = peer.SendBuffersCountPacket(sendBuffer)
if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
}
@ -242,7 +246,7 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.timersAnyAuthenticatedPacketSent()
// TODO: allocation could be avoided
err = peer.SendBuffers([][]byte{junkedHeader})
err = peer.SendBuffersCountPacket([][]byte{junkedHeader})
if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
}
@ -596,7 +600,7 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
err := peer.SendBuffers(bufs)
err := peer.SendBuffersCountPacket(bufs)
if dataSent {
peer.timersDataSent()
}

View file

@ -394,9 +394,9 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse itime %w", err)
}
device.log.Verbosef("UAPI: Updating itime %s", itime)
device.log.Verbosef("UAPI: Updating itime %d", itime)
tempAwg.HandshakeHandler.ITimeout = time.Duration(itime)
tempAwg.HandshakeHandler.ITimeout = time.Duration(itime) * time.Second
tempAwg.HandshakeHandler.IsSet = true
default:
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)