From c38c3ed54fd5f1389904305272de0d75ac5955ce Mon Sep 17 00:00:00 2001 From: Mark Puha Date: Wed, 9 Jul 2025 19:54:36 +0200 Subject: [PATCH] feat: working ranged magic headers --- device/device.go | 21 ++++++++++++--------- device/device_test.go | 8 ++++---- device/noise-protocol.go | 2 ++ device/receive.go | 8 +++++++- device/send.go | 2 +- go.mod | 1 + go.sum | 2 ++ 7 files changed, 29 insertions(+), 15 deletions(-) diff --git a/device/device.go b/device/device.go index 22b73d0..dbb975b 100644 --- a/device/device.go +++ b/device/device.go @@ -665,6 +665,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { if tempAwg.ASecCfg.ResponsePacketMagicHeader.Min > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating response_packet_magic_header") + device.awg.ASecCfg.ResponsePacketMagicHeader = tempAwg.ASecCfg.ResponsePacketMagicHeader MessageResponseType = device.awg.ASecCfg.ResponsePacketMagicHeader.Min limits[1] = tempAwg.ASecCfg.ResponsePacketMagicHeader } else { @@ -690,7 +691,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { device.log.Verbosef("UAPI: Updating transport_packet_magic_header") device.awg.ASecCfg.TransportPacketMagicHeader = tempAwg.ASecCfg.TransportPacketMagicHeader MessageTransportType = device.awg.ASecCfg.TransportPacketMagicHeader.Min - limits[3] = tempAwg.ASecCfg.UnderloadPacketMagicHeader + limits[3] = tempAwg.ASecCfg.TransportPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default transport type") MessageTransportType = DefaultMessageTransportType @@ -704,6 +705,8 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { MessageTransportType: {}, } + device.awg.MagicHeaders = awg.NewLimits(limits) + // size will be different if same values if len(isSameHeaderMap) != 4 { errs = append(errs, ipcErrorf( @@ -717,8 +720,6 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { ) } - device.awg.MagicHeaders = awg.NewLimits(limits) - newInitSize := MessageInitiationSize + tempAwg.ASecCfg.InitHeaderJunkSize if newInitSize >= MaxSegmentSize { @@ -846,8 +847,6 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { return errors.Join(errs...) } -var ErrContinueLoop = errors.New("continue processing") - func (device *Device) Logic(size int, packet *[]byte, bufsArrs *[MaxMessageSize]byte) (msgType uint32, err error) { // TODO: // if awg.WaitResponse.ShouldWait.IsSet() { @@ -860,6 +859,9 @@ func (device *Device) Logic(size int, packet *[]byte, bufsArrs *[MaxMessageSize] } junkSize := msgTypeToJunkSize[assumedMsgType] + fmt.Println(msgTypeToJunkSize) + fmt.Printf("Assumed message type: %d; size: %d", assumedMsgType, junkSize) + // transport size can align with other header types; // making sure we have the right msgType msgType, err = device.getMsgType(packet, junkSize) @@ -877,10 +879,11 @@ func (device *Device) Logic(size int, packet *[]byte, bufsArrs *[MaxMessageSize] } func (device *Device) getMsgType(packet *[]byte, junkSize int) (uint32, error) { - msgType := binary.LittleEndian.Uint32((*packet)[junkSize : junkSize+4]) - msgType, err := device.awg.GetLimitMin(msgType) + msgTypeRange := binary.LittleEndian.Uint32((*packet)[junkSize : junkSize+4]) + msgType, err := device.awg.GetLimitMin(msgTypeRange) + if err != nil { - return 0, fmt.Errorf("aSec: get limit min for message type %d: %w", msgType, err) + return 0, fmt.Errorf("aSec: get limit min for message type range: %d; %w", msgTypeRange, err) } return msgType, nil } @@ -890,7 +893,7 @@ func (device *Device) handleTransport(size int, packet *[]byte, bufsArrs *[MaxMe msgType, err := device.getMsgType(packet, device.awg.ASecCfg.TransportHeaderJunkSize) if err != nil { - return 0, fmt.Errorf("aSec: get msg type: %w", err) + return 0, fmt.Errorf("get msg type: %w", err) } if msgType != MessageTransportType { diff --git a/device/device_test.go b/device/device_test.go index e3128e4..8966218 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -236,10 +236,10 @@ func TestAWGDevicePing(t *testing.T) { "s2", "18", "s3", "20", "s4", "25", - "h1", "123456", - "h2", "67543", - "h3", "123123", - "h4", "32345", + "h1", "123456-123500", + "h2", "67543-67550", + "h3", "123123-123200", + "h4", "32345-32350", ) t.Run("ping 1.0.0.1", func(t *testing.T) { pair.Send(t, Ping, nil) diff --git a/device/noise-protocol.go b/device/noise-protocol.go index 57a677d..aa7b0e6 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -271,6 +271,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { ) device.awg.ASecMux.RLock() + if msg.Type != MessageInitiationType { device.awg.ASecMux.RUnlock() return nil @@ -448,6 +449,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { device.awg.ASecMux.RLock() + if msg.Type != MessageResponseType { device.awg.ASecMux.RUnlock() return nil diff --git a/device/receive.go b/device/receive.go index d1ce3b6..41ffa2b 100644 --- a/device/receive.go +++ b/device/receive.go @@ -142,7 +142,7 @@ func (device *Device) RoutineReceiveIncoming( if device.isAWG() { msgType, err = device.Logic(size, &packet, bufsArrs[i]) if err != nil { - device.log.Verbosef("awg device logic: %w", err) + device.log.Verbosef("awg device logic: %v", err) continue } } else { @@ -378,6 +378,9 @@ func (device *Device) RoutineHandshake(id int) { goto skip } + // have to reassign msgType for ranged msgType to work + msg.Type = elem.msgType + // consume initiation peer := device.ConsumeMessageInitiation(&msg) if peer == nil { @@ -410,6 +413,9 @@ func (device *Device) RoutineHandshake(id int) { goto skip } + // have to reassign msgType for ranged msgType to work + msg.Type = elem.msgType + // consume response peer := device.ConsumeMessageResponse(&msg) diff --git a/device/send.go b/device/send.go index 7a4bb16..4d07066 100644 --- a/device/send.go +++ b/device/send.go @@ -537,7 +537,7 @@ func (device *Device) RoutineEncryption(id int) { msgType, err := device.awg.Get(DefaultMessageTransportType) if err != nil { - device.log.Errorf("Get message type for transport: %v", err) + device.log.Errorf("get message type for transport: %v", err) continue } binary.LittleEndian.PutUint32(fieldType, msgType) diff --git a/go.mod b/go.mod index b000b63..44587fa 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/tevino/abool v1.2.0 go.uber.org/atomic v1.11.0 golang.org/x/crypto v0.36.0 + golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 golang.org/x/net v0.37.0 golang.org/x/sys v0.31.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 diff --git a/go.sum b/go.sum index 2860376..3676ea5 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,8 @@ go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= +golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= +golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=