diff --git a/device/awg/special_handshake_handler.go b/device/awg/special_handshake_handler.go index 6ba29e8..fb83a53 100644 --- a/device/awg/special_handshake_handler.go +++ b/device/awg/special_handshake_handler.go @@ -4,12 +4,22 @@ import ( "errors" "time" + "github.com/tevino/abool" "go.uber.org/atomic" ) -// TODO: atomic/ and better way to use this +// TODO: atomic?/ and better way to use this var PacketCounter *atomic.Uint64 = atomic.NewUint64(0) +// TODO +var WaitResponse = struct { + Channel chan struct{} + ShouldWait *abool.AtomicBool +}{ + make(chan struct{}, 1), + abool.New(), +} + type SpecialHandshakeHandler struct { isFirstDone bool SpecialJunk TagJunkGeneratorHandler @@ -39,7 +49,7 @@ func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte { // TODO: create tests if !handler.isFirstDone { handler.isFirstDone = true - handler.nextItime = time.Now().Add(time.Duration(handler.ITimeout)) + handler.nextItime = time.Now().Add(handler.ITimeout) return nil } @@ -48,7 +58,7 @@ func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte { } rv := handler.SpecialJunk.GeneratePackets() - handler.nextItime = time.Now().Add(time.Duration(handler.ITimeout)) + handler.nextItime = time.Now().Add(handler.ITimeout) return rv } diff --git a/device/awg/tag_generator.go b/device/awg/tag_generator.go index 6e18384..4aaa727 100644 --- a/device/awg/tag_generator.go +++ b/device/awg/tag_generator.go @@ -99,7 +99,6 @@ type TimestampGenerator struct { func (tg *TimestampGenerator) Generate() []byte { buf := make([]byte, 8) binary.BigEndian.PutUint64(buf, uint64(time.Now().Unix())) - fmt.Printf("timestamp: %v\n", buf) return buf } @@ -120,7 +119,6 @@ type WaitTimeoutGenerator struct { } func (wtg *WaitTimeoutGenerator) Generate() []byte { - fmt.Printf("sleep: %d\n", wtg.waitTimeout.Milliseconds()) time.Sleep(wtg.waitTimeout) return []byte{} } @@ -130,27 +128,25 @@ func (wtg *WaitTimeoutGenerator) Size() int { } func newWaitTimeoutGenerator(param string) (Generator, error) { - t, err := strconv.Atoi(param) + timeout, err := strconv.Atoi(param) if err != nil { return nil, fmt.Errorf("timeout parse int: %w", err) } - if t > 5000 { + if timeout > 5000 { return nil, fmt.Errorf("timeout must be less than 5000ms") } - return &WaitTimeoutGenerator{waitTimeout: time.Duration(t) * time.Millisecond}, nil + return &WaitTimeoutGenerator{waitTimeout: time.Duration(timeout) * time.Millisecond}, nil } type PacketCounterGenerator struct { - // counter *atomic.Uint64 } func (c *PacketCounterGenerator) Generate() []byte { buf := make([]byte, 8) // TODO: better way to handle counter tag binary.BigEndian.PutUint64(buf, PacketCounter.Load()) - fmt.Printf("packet %d; counter: %v\n", PacketCounter.Load(), buf) return buf } @@ -163,6 +159,27 @@ func newPacketCounterGenerator(param string) (Generator, error) { return nil, fmt.Errorf("packet counter param needs to be empty: %s", param) } - // return &PacketCounterGenerator{counter: atomic.NewUint64(0)}, nil return &PacketCounterGenerator{}, nil } + +type WaitResponseGenerator struct { +} + +func (c *WaitResponseGenerator) Generate() []byte { + WaitResponse.ShouldWait.Set() + <-WaitResponse.Channel + WaitResponse.ShouldWait.UnSet() + return []byte{} +} + +func (c *WaitResponseGenerator) Size() int { + return 0 +} + +func newWaitResponseGenerator(param string) (Generator, error) { + if len(param) != 0 { + return nil, fmt.Errorf("wait response param needs to be empty: %s", param) + } + + return &WaitResponseGenerator{}, nil +} diff --git a/device/awg/tag_junk_generator_handler.go b/device/awg/tag_junk_generator_handler.go index dfa53cf..934dadf 100644 --- a/device/awg/tag_junk_generator_handler.go +++ b/device/awg/tag_junk_generator_handler.go @@ -43,10 +43,13 @@ func (handler *TagJunkGeneratorHandler) Validate() error { func (handler *TagJunkGeneratorHandler) GeneratePackets() [][]byte { var rv = make([][]byte, 0, handler.length+handler.DefaultJunkCount) + for i, tagGenerator := range handler.tagGenerators { + PacketCounter.Inc() rv = append(rv, make([]byte, tagGenerator.packetSize)) copy(rv[i], tagGenerator.generatePacket()) } + PacketCounter.Add(uint64(handler.DefaultJunkCount)) return rv } diff --git a/device/awg/tag_parser.go b/device/awg/tag_parser.go index d6dba1b..1180ac6 100644 --- a/device/awg/tag_parser.go +++ b/device/awg/tag_parser.go @@ -19,12 +19,12 @@ const ( ) var generatorCreator = map[EnumTag]newGenerator{ - BytesEnumTag: newBytesGenerator, - CounterEnumTag: newPacketCounterGenerator, - TimestampEnumTag: newTimestampGenerator, - RandomBytesEnumTag: newRandomPacketGenerator, - WaitTimeoutEnumTag: newWaitTimeoutGenerator, - WaitResponseEnumTag: func(s string) (Generator, error) { return &BytesGenerator{}, nil }, + BytesEnumTag: newBytesGenerator, + CounterEnumTag: newPacketCounterGenerator, + TimestampEnumTag: newTimestampGenerator, + RandomBytesEnumTag: newRandomPacketGenerator, + WaitTimeoutEnumTag: newWaitTimeoutGenerator, + // WaitResponseEnumTag: newWaitResponseGenerator, } // helper map to determine enumTags are unique diff --git a/device/device_test.go b/device/device_test.go index b57f3e6..a72159a 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -176,7 +176,7 @@ func (pair *testPair) Send( } msg := tuntest.Ping(p0.ip, p1.ip) p1.tun.Outbound <- msg - timer := time.NewTimer(5 * time.Second) + timer := time.NewTimer(6 * time.Second) defer timer.Stop() var err error select { @@ -289,11 +289,12 @@ func TestAWGDevicePing(t *testing.T) { func TestAWGHandshakeDevicePing(t *testing.T) { goroutineLeakCheck(t) pair := genTestPair(t, true, - // "i1", "", - // "i2", "", + "i1", "", + "i2", "", "j1", "", "j2", "", "j3", "", + "itime", "60", // "jc", "1", // "jmin", "500", // "jmax", "1000", @@ -305,10 +306,16 @@ func TestAWGHandshakeDevicePing(t *testing.T) { // "h3", "123123", ) t.Run("ping 1.0.0.1", func(t *testing.T) { - pair.Send(t, Ping, nil) + for { + pair.Send(t, Ping, nil) + time.Sleep(2 * time.Second) + } }) t.Run("ping 1.0.0.2", func(t *testing.T) { - pair.Send(t, Pong, nil) + for { + pair.Send(t, Pong, nil) + time.Sleep(2 * time.Second) + } }) } diff --git a/device/peer.go b/device/peer.go index 7705b6f..9c7eaab 100644 --- a/device/peer.go +++ b/device/peer.go @@ -138,10 +138,6 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error { if err == nil { var totalLen uint64 for _, b := range buffers { - // TODO - awg.PacketCounter.Inc() - peer.device.log.Verbosef("%v - Sending %d bytes to %s; pc: %d", peer, len(b), endpoint) - totalLen += uint64(len(b)) } peer.txBytes.Add(totalLen) @@ -149,6 +145,16 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error { return err } +func (peer *Peer) SendBuffersCountPacket(buffers [][]byte) error { + err := peer.SendBuffers(buffers) + if err == nil { + awg.PacketCounter.Add(uint64(len(buffers))) + return nil + } + + return err +} + func (peer *Peer) String() string { // The awful goo that follows is identical to: // diff --git a/device/receive.go b/device/receive.go index 1875c76..6eaa746 100644 --- a/device/receive.go +++ b/device/receive.go @@ -137,10 +137,14 @@ func (device *Device) RoutineReceiveIncoming( } // check size of packet - packet := bufsArrs[i][:size] var msgType uint32 if device.isAdvancedSecurityOn() { + // TODO: + // if awg.WaitResponse.ShouldWait.IsSet() { + // awg.WaitResponse.Channel <- struct{}{} + // } + if assumedMsgType, ok := packetSizeToMsgType[size]; ok { junkSize := msgTypeToJunkSize[assumedMsgType] // transport size can align with other header types; diff --git a/device/send.go b/device/send.go index f87d4ae..74015c2 100644 --- a/device/send.go +++ b/device/send.go @@ -134,8 +134,12 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { // set junks depending on packet type junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk() if junks == nil { - peer.device.log.Verbosef("%v - No special junks defined, using controlled", peer) junks = peer.device.awg.HandshakeHandler.GenerateControlledJunk() + if junks != nil { + peer.device.log.Verbosef("%v - Controlled junks sent", peer) + } + } else { + peer.device.log.Verbosef("%v - Special junks sent", peer) } peer.device.awg.ASecMux.RUnlock() } else { @@ -186,7 +190,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { sendBuffer = append(sendBuffer, junkedHeader) - err = peer.SendBuffers(sendBuffer) + err = peer.SendBuffersCountPacket(sendBuffer) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err) } @@ -242,7 +246,7 @@ func (peer *Peer) SendHandshakeResponse() error { peer.timersAnyAuthenticatedPacketSent() // TODO: allocation could be avoided - err = peer.SendBuffers([][]byte{junkedHeader}) + err = peer.SendBuffersCountPacket([][]byte{junkedHeader}) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err) } @@ -596,7 +600,7 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err := peer.SendBuffers(bufs) + err := peer.SendBuffersCountPacket(bufs) if dataSent { peer.timersDataSent() } diff --git a/device/uapi.go b/device/uapi.go index 216f7cd..1e95e48 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -394,9 +394,9 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "parse itime %w", err) } - device.log.Verbosef("UAPI: Updating itime %s", itime) + device.log.Verbosef("UAPI: Updating itime %d", itime) - tempAwg.HandshakeHandler.ITimeout = time.Duration(itime) + tempAwg.HandshakeHandler.ITimeout = time.Duration(itime) * time.Second tempAwg.HandshakeHandler.IsSet = true default: return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)