From 2749e6b043409863b7a7e7229a23d520477eccf2 Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Tue, 8 Jul 2025 21:14:28 +0200
Subject: [PATCH] feat: first testable version of ranged magic headers
---
device/awg/awg.go | 60 ++++++++++++++++++++++++----------------
device/awg/util.go | 2 +-
device/cookie.go | 3 +-
device/cookie_test.go | 2 +-
device/device.go | 19 +++++++++----
device/noise-protocol.go | 15 ++++++++--
device/send.go | 16 +++++++++--
device/uapi.go | 53 ++++++++++++++++++++---------------
8 files changed, 111 insertions(+), 59 deletions(-)
diff --git a/device/awg/awg.go b/device/awg/awg.go
index 838e7ca..ef97313 100644
--- a/device/awg/awg.go
+++ b/device/awg/awg.go
@@ -3,7 +3,6 @@ package awg
import (
"bytes"
"fmt"
- "slices"
"strconv"
"strings"
"sync"
@@ -32,24 +31,29 @@ type aSecCfgType struct {
}
type Limit struct {
- Min uint32
- Max uint32
- HeaderType uint32
+ Min uint32
+ Max uint32
}
-func NewLimit(min, max, headerType uint32) (Limit, error) {
+func NewLimitSameValue(value uint32) Limit {
+ return Limit{
+ Min: value,
+ Max: value,
+ }
+}
+
+func NewLimit(min, max uint32) (Limit, error) {
if min > max {
return Limit{}, fmt.Errorf("min (%d) cannot be greater than max (%d)", min, max)
}
return Limit{
- Min: min,
- Max: max,
- HeaderType: headerType,
+ Min: min,
+ Max: max,
}, nil
}
-func ParseMagicHeader(key, value string, defaultHeaderType uint32) (Limit, error) {
+func ParseMagicHeader(key, value string) (Limit, error) {
splitLimits := strings.Split(value, "-")
if len(splitLimits) != 2 {
magicHeader, err := strconv.ParseUint(value, 10, 32)
@@ -57,7 +61,7 @@ func ParseMagicHeader(key, value string, defaultHeaderType uint32) (Limit, error
return Limit{}, fmt.Errorf("parse key: %s; value: %s; %w", key, value, err)
}
- return NewLimit(uint32(magicHeader), uint32(magicHeader), defaultHeaderType)
+ return NewLimit(uint32(magicHeader), uint32(magicHeader))
}
min, err := strconv.ParseUint(splitLimits[0], 10, 32)
@@ -70,7 +74,7 @@ func ParseMagicHeader(key, value string, defaultHeaderType uint32) (Limit, error
return Limit{}, fmt.Errorf("parse max key: %s; value: %s; %w", key, splitLimits[1], err)
}
- limit, err := NewLimit(uint32(min), uint32(max), defaultHeaderType)
+ limit, err := NewLimit(uint32(min), uint32(max))
if err != nil {
return Limit{}, fmt.Errorf("new limit key: %s; value: %s-%s; %w", key, splitLimits[0], splitLimits[1], err)
}
@@ -78,19 +82,22 @@ func ParseMagicHeader(key, value string, defaultHeaderType uint32) (Limit, error
return limit, nil
}
-type Limits []Limit
+type Limits struct {
+ Limits []Limit
+ randomGenerator PRNG[uint32]
+}
-func NewLimits(limits ...Limit) Limits {
- slices.SortFunc(limits, func(a, b Limit) int {
- if a.Min < b.Min {
- return -1
- } else if a.Min > b.Min {
- return 1
- }
- return 0
- })
+func NewLimits(limits []Limit) Limits {
+ // TODO: check if limits doesn't overlap
+ return Limits{Limits: limits, randomGenerator: NewPRNG[uint32]()}
+}
- return Limits(limits)
+func (l *Limits) Get(defaultMsgType uint32) (uint32, error) {
+ if defaultMsgType == 0 || defaultMsgType > 4 {
+ return 0, fmt.Errorf("invalid message type: %d", defaultMsgType)
+ }
+
+ return l.randomGenerator.RandomSizeInRange(l.Limits[defaultMsgType-1].Min, l.Limits[defaultMsgType-1].Max), nil
}
type Protocol struct {
@@ -102,7 +109,7 @@ type Protocol struct {
HandshakeHandler SpecialHandshakeHandler
- limits Limits
+ MagicHeaders Limits
}
func (protocol *Protocol) CreateInitHeaderJunk() ([]byte, error) {
@@ -141,7 +148,8 @@ func (protocol *Protocol) createHeaderJunk(junkSize int, extraSize int) ([]byte,
}
func (protocol *Protocol) GetLimitMin(msgType uint32) (uint32, error) {
- for _, limit := range protocol.limits {
+ fmt.Println(protocol.MagicHeaders.Limits)
+ for _, limit := range protocol.MagicHeaders.Limits {
if limit.Min <= msgType && msgType <= limit.Max {
return limit.Min, nil
}
@@ -149,3 +157,7 @@ func (protocol *Protocol) GetLimitMin(msgType uint32) (uint32, error) {
return 0, fmt.Errorf("no limit found for message type: %d", msgType)
}
+
+func (protocol *Protocol) Get(defaultMsgType uint32) (uint32, error) {
+ return protocol.MagicHeaders.Get(defaultMsgType)
+}
diff --git a/device/awg/util.go b/device/awg/util.go
index 5cb0caa..164e2b2 100644
--- a/device/awg/util.go
+++ b/device/awg/util.go
@@ -21,7 +21,7 @@ func NewPRNG[T constraints.Integer]() PRNG[T] {
}
func (p PRNG[T]) RandomSizeInRange(min, max T) T {
- if min >= max {
+ if min > max {
panic("min must be less than max")
}
diff --git a/device/cookie.go b/device/cookie.go
index 876f05d..5dfb8f5 100644
--- a/device/cookie.go
+++ b/device/cookie.go
@@ -118,6 +118,7 @@ func (st *CookieChecker) CreateReply(
msg []byte,
recv uint32,
src []byte,
+ msgType uint32,
) (*MessageCookieReply, error) {
st.RLock()
@@ -153,7 +154,7 @@ func (st *CookieChecker) CreateReply(
smac1 := smac2 - blake2s.Size128
reply := new(MessageCookieReply)
- reply.Type = MessageCookieReplyType
+ reply.Type = msgType
reply.Receiver = recv
_, err := rand.Read(reply.Nonce[:])
diff --git a/device/cookie_test.go b/device/cookie_test.go
index 4f1e50a..d076004 100644
--- a/device/cookie_test.go
+++ b/device/cookie_test.go
@@ -99,7 +99,7 @@ func TestCookieMAC1(t *testing.T) {
0x8c, 0xe1, 0xe8, 0xfa, 0x67, 0x20, 0x80, 0x6d,
}
generator.AddMacs(msg)
- reply, err := checker.CreateReply(msg, 1377, src)
+ reply, err := checker.CreateReply(msg, 1377, src, DefaultMessageCookieReplyType)
if err != nil {
t.Fatal("Failed to create cookie reply:", err)
}
diff --git a/device/device.go b/device/device.go
index a0e04cb..22b73d0 100644
--- a/device/device.go
+++ b/device/device.go
@@ -648,24 +648,29 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
isASecOn = true
}
+ limits := make([]awg.Limit, 4)
+
if tempAwg.ASecCfg.InitPacketMagicHeader.Min > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating init_packet_magic_header")
device.awg.ASecCfg.InitPacketMagicHeader = tempAwg.ASecCfg.InitPacketMagicHeader
+ limits[0] = tempAwg.ASecCfg.InitPacketMagicHeader
MessageInitiationType = device.awg.ASecCfg.InitPacketMagicHeader.Min
} else {
device.log.Verbosef("UAPI: Using default init type")
MessageInitiationType = DefaultMessageInitiationType
+ limits[0] = awg.NewLimitSameValue(DefaultMessageInitiationType)
}
if tempAwg.ASecCfg.ResponsePacketMagicHeader.Min > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
- device.awg.ASecCfg.ResponsePacketMagicHeader = tempAwg.ASecCfg.ResponsePacketMagicHeader
MessageResponseType = device.awg.ASecCfg.ResponsePacketMagicHeader.Min
+ limits[1] = tempAwg.ASecCfg.ResponsePacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default response type")
MessageResponseType = DefaultMessageResponseType
+ limits[1] = awg.NewLimitSameValue(DefaultMessageResponseType)
}
if tempAwg.ASecCfg.UnderloadPacketMagicHeader.Min > 4 {
@@ -673,9 +678,11 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
device.awg.ASecCfg.UnderloadPacketMagicHeader = tempAwg.ASecCfg.UnderloadPacketMagicHeader
MessageCookieReplyType = device.awg.ASecCfg.UnderloadPacketMagicHeader.Min
+ limits[2] = tempAwg.ASecCfg.UnderloadPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default underload type")
MessageCookieReplyType = DefaultMessageCookieReplyType
+ limits[2] = awg.NewLimitSameValue(DefaultMessageCookieReplyType)
}
if tempAwg.ASecCfg.TransportPacketMagicHeader.Min > 4 {
@@ -683,9 +690,11 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
device.awg.ASecCfg.TransportPacketMagicHeader = tempAwg.ASecCfg.TransportPacketMagicHeader
MessageTransportType = device.awg.ASecCfg.TransportPacketMagicHeader.Min
+ limits[3] = tempAwg.ASecCfg.UnderloadPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default transport type")
MessageTransportType = DefaultMessageTransportType
+ limits[3] = awg.NewLimitSameValue(DefaultMessageTransportType)
}
isSameHeaderMap := map[uint32]struct{}{
@@ -708,6 +717,8 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
)
}
+ device.awg.MagicHeaders = awg.NewLimits(limits)
+
newInitSize := MessageInitiationSize + tempAwg.ASecCfg.InitHeaderJunkSize
if newInitSize >= MaxSegmentSize {
@@ -814,11 +825,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
}
device.awg.IsASecOn.SetTo(isASecOn)
- var err error
- device.awg.JunkCreator, err = awg.NewJunkCreator(device.awg.ASecCfg)
- if err != nil {
- errs = append(errs, err)
- }
+ device.awg.JunkCreator = awg.NewJunkCreator(device.awg.ASecCfg)
if tempAwg.HandshakeHandler.IsSet {
if err := tempAwg.HandshakeHandler.Validate(); err != nil {
diff --git a/device/noise-protocol.go b/device/noise-protocol.go
index ff1ecc1..57a677d 100644
--- a/device/noise-protocol.go
+++ b/device/noise-protocol.go
@@ -206,8 +206,14 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(handshake.remoteStatic[:])
device.awg.ASecMux.RLock()
+ msgType, err := device.awg.Get(DefaultMessageInitiationType)
+ if err != nil {
+ device.awg.ASecMux.RUnlock()
+ return nil, fmt.Errorf("get message type: %w", err)
+ }
+
msg := MessageInitiation{
- Type: MessageInitiationType,
+ Type: msgType,
Ephemeral: handshake.localEphemeral.publicKey(),
}
device.awg.ASecMux.RUnlock()
@@ -385,7 +391,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
var msg MessageResponse
device.awg.ASecMux.RLock()
- msg.Type = MessageResponseType
+ msg.Type, err = device.awg.Get(DefaultMessageResponseType)
+ if err != nil {
+ device.awg.ASecMux.RUnlock()
+ return nil, fmt.Errorf("get message type: %w", err)
+ }
+
device.awg.ASecMux.RUnlock()
msg.Sender = handshake.localIndex
msg.Receiver = handshake.remoteIndex
diff --git a/device/send.go b/device/send.go
index f0db894..7a4bb16 100644
--- a/device/send.go
+++ b/device/send.go
@@ -146,7 +146,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
junks = make([][]byte, 0, peer.device.awg.ASecCfg.JunkPacketCount)
}
peer.device.awg.ASecMux.RLock()
- err := peer.device.awg.JunkCreator.CreateJunkPackets(&junks)
+ peer.device.awg.JunkCreator.CreateJunkPackets(&junks)
peer.device.awg.ASecMux.RUnlock()
if err != nil {
@@ -242,10 +242,17 @@ func (device *Device) SendHandshakeCookie(
device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
+ msgType, err := device.awg.Get(DefaultMessageCookieReplyType)
+ if err != nil {
+ device.log.Errorf("Get message type for cookie reply: %v", err)
+ return err
+ }
+
reply, err := device.cookieChecker.CreateReply(
initiatingElem.packet,
sender,
initiatingElem.endpoint.DstToBytes(),
+ msgType,
)
if err != nil {
device.log.Errorf("Failed to create cookie reply: %v", err)
@@ -528,7 +535,12 @@ func (device *Device) RoutineEncryption(id int) {
fieldReceiver := header[4:8]
fieldNonce := header[8:16]
- binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
+ msgType, err := device.awg.Get(DefaultMessageTransportType)
+ if err != nil {
+ device.log.Errorf("Get message type for transport: %v", err)
+ continue
+ }
+ binary.LittleEndian.PutUint32(fieldType, msgType)
binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
diff --git a/device/uapi.go b/device/uapi.go
index 1097e28..ce0152e 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -120,18 +120,19 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
if device.awg.ASecCfg.TransportHeaderJunkSize != 0 {
sendf("s4=%d", device.awg.ASecCfg.TransportHeaderJunkSize)
}
- if device.awg.ASecCfg.InitPacketMagicHeader != 0 {
- sendf("h1=%d", device.awg.ASecCfg.InitPacketMagicHeader)
- }
- if device.awg.ASecCfg.ResponsePacketMagicHeader != 0 {
- sendf("h2=%d", device.awg.ASecCfg.ResponsePacketMagicHeader)
- }
- if device.awg.ASecCfg.UnderloadPacketMagicHeader != 0 {
- sendf("h3=%d", device.awg.ASecCfg.UnderloadPacketMagicHeader)
- }
- if device.awg.ASecCfg.TransportPacketMagicHeader != 0 {
- sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader)
- }
+ // TODO:
+ // if device.awg.ASecCfg.InitPacketMagicHeader != 0 {
+ // sendf("h1=%d", device.awg.ASecCfg.InitPacketMagicHeader)
+ // }
+ // if device.awg.ASecCfg.ResponsePacketMagicHeader != 0 {
+ // sendf("h2=%d", device.awg.ASecCfg.ResponsePacketMagicHeader)
+ // }
+ // if device.awg.ASecCfg.UnderloadPacketMagicHeader != 0 {
+ // sendf("h3=%d", device.awg.ASecCfg.UnderloadPacketMagicHeader)
+ // }
+ // if device.awg.ASecCfg.TransportPacketMagicHeader != 0 {
+ // sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader)
+ // }
specialJunkIpcFields := device.awg.HandshakeHandler.SpecialJunk.IpcGetFields()
for _, field := range specialJunkIpcFields {
@@ -370,33 +371,41 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
tempAwg.ASecCfg.IsSet = true
case "h1":
- awg.ParseMagicHeader(key, value, &tempAwg.ASecCfg.InitPacketMagicHeader)
+ magicHeader, err := awg.ParseMagicHeader(key, value)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
+ }
+ tempAwg.ASecCfg.InitPacketMagicHeader = magicHeader
tempAwg.ASecCfg.IsSet = true
case "h2":
- responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
+ magicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil {
- return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_magic_header %w", err)
+ return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
}
- tempAwg.ASecCfg.ResponsePacketMagicHeader = uint32(responsePacketMagicHeader)
+
+ tempAwg.ASecCfg.ResponsePacketMagicHeader = magicHeader
tempAwg.ASecCfg.IsSet = true
case "h3":
- underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
+ magicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil {
- return ipcErrorf(ipc.IpcErrorInvalid, "parse underload_packet_magic_header %w", err)
+ return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
}
- tempAwg.ASecCfg.UnderloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
+
+ tempAwg.ASecCfg.UnderloadPacketMagicHeader = magicHeader
tempAwg.ASecCfg.IsSet = true
case "h4":
- transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
+ magicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil {
- return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_magic_header %w", err)
+ return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
}
- tempAwg.ASecCfg.TransportPacketMagicHeader = uint32(transportPacketMagicHeader)
+
+ tempAwg.ASecCfg.TransportPacketMagicHeader = magicHeader
tempAwg.ASecCfg.IsSet = true
+
case "i1", "i2", "i3", "i4", "i5":
if len(value) == 0 {
device.log.Verbosef("UAPI: received empty %s", key)