From 05fbf0feb0db63d55695dea095d33657e4d0766d Mon Sep 17 00:00:00 2001 From: Mark Puha Date: Wed, 2 Jul 2025 19:59:53 +0200 Subject: [PATCH] feat: add s3, s4 --- device/awg/awg.go | 59 +++++++++++++--- device/awg/junk_creator.go | 1 + device/awg/junk_creator_test.go | 4 +- device/device.go | 120 ++++++++++++++++++++++---------- device/device_test.go | 34 ++++++--- device/noise-protocol.go | 7 +- device/peer.go | 20 +++--- device/receive.go | 5 +- device/send.go | 57 ++++++++------- device/uapi.go | 36 ++++++++-- 10 files changed, 235 insertions(+), 108 deletions(-) diff --git a/device/awg/awg.go b/device/awg/awg.go index a2da6f1..b4005ae 100644 --- a/device/awg/awg.go +++ b/device/awg/awg.go @@ -1,11 +1,27 @@ package awg import ( + "bytes" "sync" "github.com/tevino/abool" ) +type aSecCfgType struct { + IsSet bool + JunkPacketCount int + JunkPacketMinSize int + JunkPacketMaxSize int + InitHeaderJunkSize int + ResponseHeaderJunkSize int + CookieReplyHeaderJunkSize int + TransportHeaderJunkSize int + InitPacketMagicHeader uint32 + ResponsePacketMagicHeader uint32 + UnderloadPacketMagicHeader uint32 + TransportPacketMagicHeader uint32 +} + type Protocol struct { IsASecOn abool.AtomicBool // TODO: revision the need of the mutex @@ -16,15 +32,36 @@ type Protocol struct { HandshakeHandler SpecialHandshakeHandler } -type aSecCfgType struct { - IsSet bool - JunkPacketCount int - JunkPacketMinSize int - JunkPacketMaxSize int - InitPacketJunkSize int - ResponsePacketJunkSize int - InitPacketMagicHeader uint32 - ResponsePacketMagicHeader uint32 - UnderloadPacketMagicHeader uint32 - TransportPacketMagicHeader uint32 +func (protocol *Protocol) CreateInitHeaderJunk() ([]byte, error) { + return protocol.createHeaderJunk(protocol.ASecCfg.InitHeaderJunkSize) +} + +func (protocol *Protocol) CreateResponseHeaderJunk() ([]byte, error) { + return protocol.createHeaderJunk(protocol.ASecCfg.ResponseHeaderJunkSize) +} + +func (protocol *Protocol) CreateCookieReplyHeaderJunk() ([]byte, error) { + return protocol.createHeaderJunk(protocol.ASecCfg.CookieReplyHeaderJunkSize) +} + +func (protocol *Protocol) CreateTransportHeaderJunk() ([]byte, error) { + return protocol.createHeaderJunk(protocol.ASecCfg.TransportHeaderJunkSize) +} + +func (protocol *Protocol) createHeaderJunk(junkSize int) ([]byte, error) { + var junk []byte + protocol.ASecMux.RLock() + if junkSize != 0 { + buf := make([]byte, 0, junkSize) + writer := bytes.NewBuffer(buf[:0]) + err := protocol.JunkCreator.AppendJunk(writer, junkSize) + if err != nil { + protocol.ASecMux.RUnlock() + return nil, err + } + junk = writer.Bytes() + } + protocol.ASecMux.RUnlock() + + return junk, nil } diff --git a/device/awg/junk_creator.go b/device/awg/junk_creator.go index 08294b1..91fd253 100644 --- a/device/awg/junk_creator.go +++ b/device/awg/junk_creator.go @@ -12,6 +12,7 @@ type junkCreator struct { cha8Rand *v2.ChaCha8 } +// TODO: refactor param to only pass the junk related params func NewJunkCreator(aSecCfg aSecCfgType) (junkCreator, error) { buf := make([]byte, 32) _, err := crand.Read(buf) diff --git a/device/awg/junk_creator_test.go b/device/awg/junk_creator_test.go index 9f80d83..424f104 100644 --- a/device/awg/junk_creator_test.go +++ b/device/awg/junk_creator_test.go @@ -12,8 +12,8 @@ func setUpJunkCreator(t *testing.T) (junkCreator, error) { JunkPacketCount: 5, JunkPacketMinSize: 500, JunkPacketMaxSize: 1000, - InitPacketJunkSize: 30, - ResponsePacketJunkSize: 40, + InitHeaderJunkSize: 30, + ResponseHeaderJunkSize: 40, InitPacketMagicHeader: 123456, ResponsePacketMagicHeader: 67543, UnderloadPacketMagicHeader: 32345, diff --git a/device/device.go b/device/device.go index 42a81c7..ac81c84 100644 --- a/device/device.go +++ b/device/device.go @@ -646,38 +646,111 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { isASecOn = true } - if MessageInitiationSize+tempAwg.ASecCfg.InitPacketJunkSize >= MaxSegmentSize { + newInitSize := MessageInitiationSize + tempAwg.ASecCfg.InitHeaderJunkSize + + if newInitSize >= MaxSegmentSize { errs = append(errs, ipcErrorf( ipc.IpcErrorInvalid, `init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`, - tempAwg.ASecCfg.InitPacketJunkSize, + tempAwg.ASecCfg.InitHeaderJunkSize, MaxSegmentSize, ), ) } else { - device.awg.ASecCfg.InitPacketJunkSize = tempAwg.ASecCfg.InitPacketJunkSize + device.awg.ASecCfg.InitHeaderJunkSize = tempAwg.ASecCfg.InitHeaderJunkSize } - if tempAwg.ASecCfg.InitPacketJunkSize != 0 { + if tempAwg.ASecCfg.InitHeaderJunkSize != 0 { isASecOn = true } - if MessageResponseSize+tempAwg.ASecCfg.ResponsePacketJunkSize >= MaxSegmentSize { + newResponseSize := MessageResponseSize + tempAwg.ASecCfg.ResponseHeaderJunkSize + + if newResponseSize >= MaxSegmentSize { errs = append(errs, ipcErrorf( ipc.IpcErrorInvalid, `response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`, - tempAwg.ASecCfg.ResponsePacketJunkSize, + tempAwg.ASecCfg.ResponseHeaderJunkSize, MaxSegmentSize, ), ) } else { - device.awg.ASecCfg.ResponsePacketJunkSize = tempAwg.ASecCfg.ResponsePacketJunkSize + device.awg.ASecCfg.ResponseHeaderJunkSize = tempAwg.ASecCfg.ResponseHeaderJunkSize } - if tempAwg.ASecCfg.ResponsePacketJunkSize != 0 { + if tempAwg.ASecCfg.ResponseHeaderJunkSize != 0 { isASecOn = true } + newCookieSize := MessageCookieReplySize + tempAwg.ASecCfg.CookieReplyHeaderJunkSize + + if newCookieSize >= MaxSegmentSize { + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, + `cookie reply size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`, + tempAwg.ASecCfg.CookieReplyHeaderJunkSize, + MaxSegmentSize, + ), + ) + } else { + device.awg.ASecCfg.CookieReplyHeaderJunkSize = tempAwg.ASecCfg.CookieReplyHeaderJunkSize + } + + if tempAwg.ASecCfg.CookieReplyHeaderJunkSize != 0 { + isASecOn = true + } + + newTransportSize := MessageTransportSize + tempAwg.ASecCfg.TransportHeaderJunkSize + + if newTransportSize >= MaxSegmentSize { + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, + `transport size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`, + tempAwg.ASecCfg.TransportHeaderJunkSize, + MaxSegmentSize, + ), + ) + } else { + device.awg.ASecCfg.TransportHeaderJunkSize = tempAwg.ASecCfg.TransportHeaderJunkSize + } + + if tempAwg.ASecCfg.TransportHeaderJunkSize != 0 { + isASecOn = true + } + + isSameSizeMap := map[int]struct{}{ + newInitSize: {}, + newResponseSize: {}, + newCookieSize: {}, + newTransportSize: {}, + } + + if len(isSameSizeMap) != 4 { + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, + `new sizes should differ; init: %d; response: %d; cookie: %d; trans: %d`, + newInitSize, + newResponseSize, + newCookieSize, + newTransportSize, + ), + ) + } else { + packetSizeToMsgType = map[int]uint32{ + newInitSize: MessageInitiationType, + newResponseSize: MessageResponseType, + newCookieSize: MessageCookieReplyType, + newTransportSize: MessageTransportType, + } + + msgTypeToJunkSize = map[uint32]int{ + MessageInitiationType: device.awg.ASecCfg.InitHeaderJunkSize, + MessageResponseType: device.awg.ASecCfg.ResponseHeaderJunkSize, + MessageCookieReplyType: device.awg.ASecCfg.CookieReplyHeaderJunkSize, + MessageTransportType: device.awg.ASecCfg.TransportHeaderJunkSize, + } + } + if tempAwg.ASecCfg.InitPacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating init_packet_magic_header") @@ -718,7 +791,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { MessageTransportType = DefaultMessageTransportType } - isSameMap := map[uint32]struct{}{ + isSameHeaderMap := map[uint32]struct{}{ MessageInitiationType: {}, MessageResponseType: {}, MessageCookieReplyType: {}, @@ -726,7 +799,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { } // size will be different if same values - if len(isSameMap) != 4 { + if len(isSameHeaderMap) != 4 { errs = append(errs, ipcErrorf( ipc.IpcErrorInvalid, `magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d`, @@ -738,33 +811,6 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { ) } - newInitSize := MessageInitiationSize + device.awg.ASecCfg.InitPacketJunkSize - newResponseSize := MessageResponseSize + device.awg.ASecCfg.ResponsePacketJunkSize - - if newInitSize == newResponseSize { - errs = append(errs, ipcErrorf( - ipc.IpcErrorInvalid, - `new init size:%d; and new response size:%d; should differ`, - newInitSize, - newResponseSize, - ), - ) - } else { - packetSizeToMsgType = map[int]uint32{ - newInitSize: MessageInitiationType, - newResponseSize: MessageResponseType, - MessageCookieReplySize: MessageCookieReplyType, - MessageTransportSize: MessageTransportType, - } - - msgTypeToJunkSize = map[uint32]int{ - MessageInitiationType: device.awg.ASecCfg.InitPacketJunkSize, - MessageResponseType: device.awg.ASecCfg.ResponsePacketJunkSize, - MessageCookieReplyType: 0, - MessageTransportType: 0, - } - } - device.awg.IsASecOn.SetTo(isASecOn) var err error device.awg.JunkCreator, err = awg.NewJunkCreator(device.awg.ASecCfg) diff --git a/device/device_test.go b/device/device_test.go index 7aec4c2..f9eb50c 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -17,6 +17,7 @@ import ( "os/signal" "runtime" "runtime/pprof" + "strconv" "sync" "testing" "time" @@ -129,6 +130,7 @@ func (pair *testPair) Send( tb testing.TB, ping SendDirection, done chan struct{}, + optTransportJunk ...int, ) { tb.Helper() p0, p1 := pair[0], pair[1] @@ -136,6 +138,12 @@ func (pair *testPair) Send( // pong is the new ping p0, p1 = p1, p0 } + + transportJunk := 0 + if len(optTransportJunk) > 0 { + transportJunk = optTransportJunk[0] + } + msg := tuntest.Ping(p0.ip, p1.ip) p1.tun.Outbound <- msg timer := time.NewTimer(6 * time.Second) @@ -143,7 +151,10 @@ func (pair *testPair) Send( var err error select { case msgRecv := <-p0.tun.Inbound: - if !bytes.Equal(msg, msgRecv) { + fmt.Printf("%x\n", msg) + fmt.Printf("%x\n", msgRecv[transportJunk:]) + fmt.Printf("%x\n", msgRecv) + if !bytes.Equal(msg, msgRecv[transportJunk:]) { err = fmt.Errorf("%s did not transit correctly", ping) } case <-timer.C: @@ -226,22 +237,27 @@ func TestTwoDevicePing(t *testing.T) { // Run test with -race=false to avoid the race for setting the default msgTypes 2 times func TestAWGDevicePing(t *testing.T) { goroutineLeakCheck(t) + + transportJunk := 5 + pair := genTestPair(t, true, "jc", "5", "jmin", "500", "jmax", "1000", - "s1", "30", - "s2", "40", - "h1", "123456", - "h2", "67543", - "h4", "32345", - "h3", "123123", + // "s1", "30", + // "s2", "40", + "s3", "50", + "s4", strconv.Itoa(transportJunk), + // "h1", "123456", + // "h2", "67543", + // "h3", "123123", + // "h4", "32345", ) t.Run("ping 1.0.0.1", func(t *testing.T) { - pair.Send(t, Ping, nil) + pair.Send(t, Ping, nil, transportJunk) }) t.Run("ping 1.0.0.2", func(t *testing.T) { - pair.Send(t, Pong, nil) + pair.Send(t, Pong, nil, transportJunk) }) } diff --git a/device/noise-protocol.go b/device/noise-protocol.go index 89c634c..ff1ecc1 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -82,9 +82,10 @@ const ( MessageTransportOffsetContent = 16 ) -var packetSizeToMsgType map[int]uint32 - -var msgTypeToJunkSize map[uint32]int +var ( + packetSizeToMsgType map[int]uint32 + msgTypeToJunkSize map[uint32]int +) /* Type is an 8-bit field, followed by 3 nul bytes, * by marshalling the messages in little-endian byteorder diff --git a/device/peer.go b/device/peer.go index fdc0b86..6ccdc30 100644 --- a/device/peer.go +++ b/device/peer.go @@ -114,6 +114,16 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { return peer, nil } +func (peer *Peer) SendAndCountBuffers(buffers [][]byte) error { + err := peer.SendBuffers(buffers) + if err == nil { + awg.PacketCounter.Add(uint64(len(buffers))) + return nil + } + + return err +} + func (peer *Peer) SendBuffers(buffers [][]byte) error { peer.device.net.RLock() defer peer.device.net.RUnlock() @@ -145,16 +155,6 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error { return err } -func (peer *Peer) SendAndCountBuffers(buffers [][]byte) error { - err := peer.SendBuffers(buffers) - if err == nil { - awg.PacketCounter.Add(uint64(len(buffers))) - return nil - } - - return err -} - func (peer *Peer) String() string { // The awful goo that follows is identical to: // diff --git a/device/receive.go b/device/receive.go index 1ab7cd9..0844b79 100644 --- a/device/receive.go +++ b/device/receive.go @@ -157,12 +157,15 @@ func (device *Device) RoutineReceiveIncoming( msgType = binary.LittleEndian.Uint32(packet[:4]) } } else { - msgType = binary.LittleEndian.Uint32(packet[:4]) + transportJunkSize := device.awg.ASecCfg.TransportHeaderJunkSize + msgType = binary.LittleEndian.Uint32(packet[transportJunkSize : transportJunkSize+4]) if msgType != MessageTransportType { // probably a junk packet device.log.Verbosef("aSec: Received message with unknown type: %d", msgType) continue } + + packet = packet[transportJunkSize:] } } else { msgType = binary.LittleEndian.Uint32(packet[:4]) diff --git a/device/send.go b/device/send.go index 305589a..234c391 100644 --- a/device/send.go +++ b/device/send.go @@ -124,9 +124,9 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { return err } var sendBuffer [][]byte + // so only packet processed for cookie generation var junkedHeader []byte - if peer.device.version >= VersionAwg { var junks [][]byte if peer.device.version == VersionAwgSpecialHandshake { @@ -163,19 +163,11 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { } } - peer.device.awg.ASecMux.RLock() - if peer.device.awg.ASecCfg.InitPacketJunkSize != 0 { - buf := make([]byte, 0, peer.device.awg.ASecCfg.InitPacketJunkSize) - writer := bytes.NewBuffer(buf[:0]) - err = peer.device.awg.JunkCreator.AppendJunk(writer, peer.device.awg.ASecCfg.InitPacketJunkSize) - if err != nil { - peer.device.log.Errorf("%v - %v", peer, err) - peer.device.awg.ASecMux.RUnlock() - return err - } - junkedHeader = writer.Bytes() + junkedHeader, err = peer.device.awg.CreateInitHeaderJunk() + if err != nil { + peer.device.log.Errorf("%v - %v", peer, err) + return err } - peer.device.awg.ASecMux.RUnlock() } var buf [MessageInitiationSize]byte @@ -211,22 +203,13 @@ func (peer *Peer) SendHandshakeResponse() error { peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err) return err } - var junkedHeader []byte - if peer.device.isAWG() { - peer.device.awg.ASecMux.RLock() - if peer.device.awg.ASecCfg.ResponsePacketJunkSize != 0 { - buf := make([]byte, 0, peer.device.awg.ASecCfg.ResponsePacketJunkSize) - writer := bytes.NewBuffer(buf[:0]) - err = peer.device.awg.JunkCreator.AppendJunk(writer, peer.device.awg.ASecCfg.ResponsePacketJunkSize) - if err != nil { - peer.device.awg.ASecMux.RUnlock() - peer.device.log.Errorf("%v - %v", peer, err) - return err - } - junkedHeader = writer.Bytes() - } - peer.device.awg.ASecMux.RUnlock() + + junkedHeader, err := peer.device.awg.CreateResponseHeaderJunk() + if err != nil { + peer.device.log.Errorf("%v - %v", peer, err) + return err } + var buf [MessageResponseSize]byte writer := bytes.NewBuffer(buf[:0]) @@ -269,11 +252,19 @@ func (device *Device) SendHandshakeCookie( return err } + junkedHeader, err := device.awg.CreateCookieReplyHeaderJunk() + if err != nil { + device.log.Errorf("%v - %v", device, err) + return err + } + var buf [MessageCookieReplySize]byte writer := bytes.NewBuffer(buf[:0]) binary.Write(writer, binary.LittleEndian, reply) + + junkedHeader = append(junkedHeader, writer.Bytes()...) // TODO: allocation could be avoided - device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint) + device.net.bind.Send([][]byte{junkedHeader}, initiatingElem.endpoint) return nil } @@ -593,6 +584,13 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { for _, elem := range elemsContainer.elems { if len(elem.packet) != MessageKeepaliveSize { dataSent = true + junkedHeader, err := device.awg.CreateTransportHeaderJunk() + if err != nil { + device.log.Errorf("%v - %v", device, err) + continue + } + + elem.packet = append(junkedHeader, elem.packet...) } bufs = append(bufs, elem.packet) } @@ -604,6 +602,7 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { if dataSent { peer.timersDataSent() } + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) diff --git a/device/uapi.go b/device/uapi.go index 9f2718d..d1511cc 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -108,11 +108,17 @@ func (device *Device) IpcGetOperation(w io.Writer) error { if device.awg.ASecCfg.JunkPacketMaxSize != 0 { sendf("jmax=%d", device.awg.ASecCfg.JunkPacketMaxSize) } - if device.awg.ASecCfg.InitPacketJunkSize != 0 { - sendf("s1=%d", device.awg.ASecCfg.InitPacketJunkSize) + if device.awg.ASecCfg.InitHeaderJunkSize != 0 { + sendf("s1=%d", device.awg.ASecCfg.InitHeaderJunkSize) } - if device.awg.ASecCfg.ResponsePacketJunkSize != 0 { - sendf("s2=%d", device.awg.ASecCfg.ResponsePacketJunkSize) + if device.awg.ASecCfg.ResponseHeaderJunkSize != 0 { + sendf("s2=%d", device.awg.ASecCfg.ResponseHeaderJunkSize) + } + if device.awg.ASecCfg.CookieReplyHeaderJunkSize != 0 { + sendf("s3=%d", device.awg.ASecCfg.CookieReplyHeaderJunkSize) + } + if device.awg.ASecCfg.TransportHeaderJunkSize != 0 { + sendf("s4=%d", device.awg.ASecCfg.TransportHeaderJunkSize) } if device.awg.ASecCfg.InitPacketMagicHeader != 0 { sendf("h1=%d", device.awg.ASecCfg.InitPacketMagicHeader) @@ -333,7 +339,7 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol) return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_junk_size %w", err) } device.log.Verbosef("UAPI: Updating init_packet_junk_size") - tempAwg.ASecCfg.InitPacketJunkSize = initPacketJunkSize + tempAwg.ASecCfg.InitHeaderJunkSize = initPacketJunkSize tempAwg.ASecCfg.IsSet = true case "s2": @@ -342,7 +348,25 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol) return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_junk_size %w", err) } device.log.Verbosef("UAPI: Updating response_packet_junk_size") - tempAwg.ASecCfg.ResponsePacketJunkSize = responsePacketJunkSize + tempAwg.ASecCfg.ResponseHeaderJunkSize = responsePacketJunkSize + tempAwg.ASecCfg.IsSet = true + + case "s3": + cookieReplyPacketJunkSize, err := strconv.Atoi(value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "parse cookie_reply_packet_junk_size %w", err) + } + device.log.Verbosef("UAPI: Updating cookie_reply_packet_junk_size") + tempAwg.ASecCfg.CookieReplyHeaderJunkSize = cookieReplyPacketJunkSize + tempAwg.ASecCfg.IsSet = true + + case "s4": + transportPacketJunkSize, err := strconv.Atoi(value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_junk_size %w", err) + } + device.log.Verbosef("UAPI: Updating transport_packet_junk_size") + tempAwg.ASecCfg.TransportHeaderJunkSize = transportPacketJunkSize tempAwg.ASecCfg.IsSet = true case "h1":