From c38c3ed54fd5f1389904305272de0d75ac5955ce Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Wed, 9 Jul 2025 19:54:36 +0200
Subject: [PATCH] feat: working ranged magic headers
---
device/device.go | 21 ++++++++++++---------
device/device_test.go | 8 ++++----
device/noise-protocol.go | 2 ++
device/receive.go | 8 +++++++-
device/send.go | 2 +-
go.mod | 1 +
go.sum | 2 ++
7 files changed, 29 insertions(+), 15 deletions(-)
diff --git a/device/device.go b/device/device.go
index 22b73d0..dbb975b 100644
--- a/device/device.go
+++ b/device/device.go
@@ -665,6 +665,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
if tempAwg.ASecCfg.ResponsePacketMagicHeader.Min > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
+ device.awg.ASecCfg.ResponsePacketMagicHeader = tempAwg.ASecCfg.ResponsePacketMagicHeader
MessageResponseType = device.awg.ASecCfg.ResponsePacketMagicHeader.Min
limits[1] = tempAwg.ASecCfg.ResponsePacketMagicHeader
} else {
@@ -690,7 +691,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
device.awg.ASecCfg.TransportPacketMagicHeader = tempAwg.ASecCfg.TransportPacketMagicHeader
MessageTransportType = device.awg.ASecCfg.TransportPacketMagicHeader.Min
- limits[3] = tempAwg.ASecCfg.UnderloadPacketMagicHeader
+ limits[3] = tempAwg.ASecCfg.TransportPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default transport type")
MessageTransportType = DefaultMessageTransportType
@@ -704,6 +705,8 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
MessageTransportType: {},
}
+ device.awg.MagicHeaders = awg.NewLimits(limits)
+
// size will be different if same values
if len(isSameHeaderMap) != 4 {
errs = append(errs, ipcErrorf(
@@ -717,8 +720,6 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
)
}
- device.awg.MagicHeaders = awg.NewLimits(limits)
-
newInitSize := MessageInitiationSize + tempAwg.ASecCfg.InitHeaderJunkSize
if newInitSize >= MaxSegmentSize {
@@ -846,8 +847,6 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
return errors.Join(errs...)
}
-var ErrContinueLoop = errors.New("continue processing")
-
func (device *Device) Logic(size int, packet *[]byte, bufsArrs *[MaxMessageSize]byte) (msgType uint32, err error) {
// TODO:
// if awg.WaitResponse.ShouldWait.IsSet() {
@@ -860,6 +859,9 @@ func (device *Device) Logic(size int, packet *[]byte, bufsArrs *[MaxMessageSize]
}
junkSize := msgTypeToJunkSize[assumedMsgType]
+ fmt.Println(msgTypeToJunkSize)
+ fmt.Printf("Assumed message type: %d; size: %d", assumedMsgType, junkSize)
+
// transport size can align with other header types;
// making sure we have the right msgType
msgType, err = device.getMsgType(packet, junkSize)
@@ -877,10 +879,11 @@ func (device *Device) Logic(size int, packet *[]byte, bufsArrs *[MaxMessageSize]
}
func (device *Device) getMsgType(packet *[]byte, junkSize int) (uint32, error) {
- msgType := binary.LittleEndian.Uint32((*packet)[junkSize : junkSize+4])
- msgType, err := device.awg.GetLimitMin(msgType)
+ msgTypeRange := binary.LittleEndian.Uint32((*packet)[junkSize : junkSize+4])
+ msgType, err := device.awg.GetLimitMin(msgTypeRange)
+
if err != nil {
- return 0, fmt.Errorf("aSec: get limit min for message type %d: %w", msgType, err)
+ return 0, fmt.Errorf("aSec: get limit min for message type range: %d; %w", msgTypeRange, err)
}
return msgType, nil
}
@@ -890,7 +893,7 @@ func (device *Device) handleTransport(size int, packet *[]byte, bufsArrs *[MaxMe
msgType, err := device.getMsgType(packet, device.awg.ASecCfg.TransportHeaderJunkSize)
if err != nil {
- return 0, fmt.Errorf("aSec: get msg type: %w", err)
+ return 0, fmt.Errorf("get msg type: %w", err)
}
if msgType != MessageTransportType {
diff --git a/device/device_test.go b/device/device_test.go
index e3128e4..8966218 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -236,10 +236,10 @@ func TestAWGDevicePing(t *testing.T) {
"s2", "18",
"s3", "20",
"s4", "25",
- "h1", "123456",
- "h2", "67543",
- "h3", "123123",
- "h4", "32345",
+ "h1", "123456-123500",
+ "h2", "67543-67550",
+ "h3", "123123-123200",
+ "h4", "32345-32350",
)
t.Run("ping 1.0.0.1", func(t *testing.T) {
pair.Send(t, Ping, nil)
diff --git a/device/noise-protocol.go b/device/noise-protocol.go
index 57a677d..aa7b0e6 100644
--- a/device/noise-protocol.go
+++ b/device/noise-protocol.go
@@ -271,6 +271,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
)
device.awg.ASecMux.RLock()
+
if msg.Type != MessageInitiationType {
device.awg.ASecMux.RUnlock()
return nil
@@ -448,6 +449,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
device.awg.ASecMux.RLock()
+
if msg.Type != MessageResponseType {
device.awg.ASecMux.RUnlock()
return nil
diff --git a/device/receive.go b/device/receive.go
index d1ce3b6..41ffa2b 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -142,7 +142,7 @@ func (device *Device) RoutineReceiveIncoming(
if device.isAWG() {
msgType, err = device.Logic(size, &packet, bufsArrs[i])
if err != nil {
- device.log.Verbosef("awg device logic: %w", err)
+ device.log.Verbosef("awg device logic: %v", err)
continue
}
} else {
@@ -378,6 +378,9 @@ func (device *Device) RoutineHandshake(id int) {
goto skip
}
+ // have to reassign msgType for ranged msgType to work
+ msg.Type = elem.msgType
+
// consume initiation
peer := device.ConsumeMessageInitiation(&msg)
if peer == nil {
@@ -410,6 +413,9 @@ func (device *Device) RoutineHandshake(id int) {
goto skip
}
+ // have to reassign msgType for ranged msgType to work
+ msg.Type = elem.msgType
+
// consume response
peer := device.ConsumeMessageResponse(&msg)
diff --git a/device/send.go b/device/send.go
index 7a4bb16..4d07066 100644
--- a/device/send.go
+++ b/device/send.go
@@ -537,7 +537,7 @@ func (device *Device) RoutineEncryption(id int) {
msgType, err := device.awg.Get(DefaultMessageTransportType)
if err != nil {
- device.log.Errorf("Get message type for transport: %v", err)
+ device.log.Errorf("get message type for transport: %v", err)
continue
}
binary.LittleEndian.PutUint32(fieldType, msgType)
diff --git a/go.mod b/go.mod
index b000b63..44587fa 100644
--- a/go.mod
+++ b/go.mod
@@ -7,6 +7,7 @@ require (
github.com/tevino/abool v1.2.0
go.uber.org/atomic v1.11.0
golang.org/x/crypto v0.36.0
+ golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
golang.org/x/net v0.37.0
golang.org/x/sys v0.31.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
diff --git a/go.sum b/go.sum
index 2860376..3676ea5 100644
--- a/go.sum
+++ b/go.sum
@@ -14,6 +14,8 @@ go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
+golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
+golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=