From c50499d50ec8ceb3e94bdb276bbd4ee93741ff6f Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Wed, 9 Jul 2025 20:35:02 +0200
Subject: [PATCH] chore: some cleanup
---
device/awg/awg.go | 142 ++++++++++---------------------------
device/awg/magic_header.go | 91 ++++++++++++++++++++++++
device/device.go | 22 +++---
device/noise-protocol.go | 4 +-
device/send.go | 9 +--
device/uapi.go | 16 ++---
6 files changed, 154 insertions(+), 130 deletions(-)
create mode 100644 device/awg/magic_header.go
diff --git a/device/awg/awg.go b/device/awg/awg.go
index ef97313..6138c5f 100644
--- a/device/awg/awg.go
+++ b/device/awg/awg.go
@@ -3,8 +3,6 @@ package awg
import (
"bytes"
"fmt"
- "strconv"
- "strings"
"sync"
"github.com/tevino/abool"
@@ -19,145 +17,81 @@ type aSecCfgType struct {
ResponseHeaderJunkSize int
CookieReplyHeaderJunkSize int
TransportHeaderJunkSize int
- // InitPacketMagicHeader uint32
- // ResponsePacketMagicHeader uint32
- // UnderloadPacketMagicHeader uint32
- // TransportPacketMagicHeader uint32
- InitPacketMagicHeader Limit
- ResponsePacketMagicHeader Limit
- UnderloadPacketMagicHeader Limit
- TransportPacketMagicHeader Limit
-}
-
-type Limit struct {
- Min uint32
- Max uint32
-}
-
-func NewLimitSameValue(value uint32) Limit {
- return Limit{
- Min: value,
- Max: value,
- }
-}
-
-func NewLimit(min, max uint32) (Limit, error) {
- if min > max {
- return Limit{}, fmt.Errorf("min (%d) cannot be greater than max (%d)", min, max)
- }
-
- return Limit{
- Min: min,
- Max: max,
- }, nil
-}
-
-func ParseMagicHeader(key, value string) (Limit, error) {
- splitLimits := strings.Split(value, "-")
- if len(splitLimits) != 2 {
- magicHeader, err := strconv.ParseUint(value, 10, 32)
- if err != nil {
- return Limit{}, fmt.Errorf("parse key: %s; value: %s; %w", key, value, err)
- }
-
- return NewLimit(uint32(magicHeader), uint32(magicHeader))
- }
-
- min, err := strconv.ParseUint(splitLimits[0], 10, 32)
- if err != nil {
- return Limit{}, fmt.Errorf("parse min key: %s; value: %s; %w", key, splitLimits[0], err)
- }
-
- max, err := strconv.ParseUint(splitLimits[1], 10, 32)
- if err != nil {
- return Limit{}, fmt.Errorf("parse max key: %s; value: %s; %w", key, splitLimits[1], err)
- }
-
- limit, err := NewLimit(uint32(min), uint32(max))
- if err != nil {
- return Limit{}, fmt.Errorf("new limit key: %s; value: %s-%s; %w", key, splitLimits[0], splitLimits[1], err)
- }
-
- return limit, nil
-}
-
-type Limits struct {
- Limits []Limit
- randomGenerator PRNG[uint32]
-}
-
-func NewLimits(limits []Limit) Limits {
- // TODO: check if limits doesn't overlap
- return Limits{Limits: limits, randomGenerator: NewPRNG[uint32]()}
-}
-
-func (l *Limits) Get(defaultMsgType uint32) (uint32, error) {
- if defaultMsgType == 0 || defaultMsgType > 4 {
- return 0, fmt.Errorf("invalid message type: %d", defaultMsgType)
- }
-
- return l.randomGenerator.RandomSizeInRange(l.Limits[defaultMsgType-1].Min, l.Limits[defaultMsgType-1].Max), nil
+ InitPacketMagicHeader MagicHeader
+ ResponsePacketMagicHeader MagicHeader
+ UnderloadPacketMagicHeader MagicHeader
+ TransportPacketMagicHeader MagicHeader
}
type Protocol struct {
IsASecOn abool.AtomicBool
// TODO: revision the need of the mutex
- ASecMux sync.RWMutex
- ASecCfg aSecCfgType
- JunkCreator junkCreator
+ ASecMux sync.RWMutex
+ ASecCfg aSecCfgType
+ JunkCreator junkCreator
+ MagicHeaders MagicHeaders
HandshakeHandler SpecialHandshakeHandler
-
- MagicHeaders Limits
}
func (protocol *Protocol) CreateInitHeaderJunk() ([]byte, error) {
+ protocol.ASecMux.RLock()
+ defer protocol.ASecMux.RUnlock()
+
return protocol.createHeaderJunk(protocol.ASecCfg.InitHeaderJunkSize, 0)
}
func (protocol *Protocol) CreateResponseHeaderJunk() ([]byte, error) {
+ protocol.ASecMux.RLock()
+ defer protocol.ASecMux.RUnlock()
+
return protocol.createHeaderJunk(protocol.ASecCfg.ResponseHeaderJunkSize, 0)
}
func (protocol *Protocol) CreateCookieReplyHeaderJunk() ([]byte, error) {
+ protocol.ASecMux.RLock()
+ defer protocol.ASecMux.RUnlock()
+
return protocol.createHeaderJunk(protocol.ASecCfg.CookieReplyHeaderJunkSize, 0)
}
func (protocol *Protocol) CreateTransportHeaderJunk(packetSize int) ([]byte, error) {
+ protocol.ASecMux.RLock()
+ defer protocol.ASecMux.RUnlock()
+
return protocol.createHeaderJunk(protocol.ASecCfg.TransportHeaderJunkSize, packetSize)
}
func (protocol *Protocol) createHeaderJunk(junkSize int, extraSize int) ([]byte, error) {
- var junk []byte
- protocol.ASecMux.RLock()
-
- if junkSize != 0 {
- buf := make([]byte, 0, junkSize+extraSize)
- writer := bytes.NewBuffer(buf[:0])
- err := protocol.JunkCreator.AppendJunk(writer, junkSize)
- if err != nil {
- protocol.ASecMux.RUnlock()
- return nil, err
- }
- junk = writer.Bytes()
+ if junkSize == 0 {
+ return nil, nil
}
- protocol.ASecMux.RUnlock()
+
+ var junk []byte
+ buf := make([]byte, 0, junkSize+extraSize)
+ writer := bytes.NewBuffer(buf[:0])
+
+ err := protocol.JunkCreator.AppendJunk(writer, junkSize)
+ if err != nil {
+ return nil, fmt.Errorf("append junk: %w", err)
+ }
+
+ junk = writer.Bytes()
return junk, nil
}
-func (protocol *Protocol) GetLimitMin(msgType uint32) (uint32, error) {
- fmt.Println(protocol.MagicHeaders.Limits)
- for _, limit := range protocol.MagicHeaders.Limits {
- if limit.Min <= msgType && msgType <= limit.Max {
+func (protocol *Protocol) GetLimitMin(msgTypeRange uint32) (uint32, error) {
+ for _, limit := range protocol.MagicHeaders.headers {
+ if limit.Min <= msgTypeRange && msgTypeRange <= limit.Max {
return limit.Min, nil
}
}
- return 0, fmt.Errorf("no limit found for message type: %d", msgType)
+ return 0, fmt.Errorf("no limit for range: %d", msgTypeRange)
}
-func (protocol *Protocol) Get(defaultMsgType uint32) (uint32, error) {
+func (protocol *Protocol) GetMsgType(defaultMsgType uint32) (uint32, error) {
return protocol.MagicHeaders.Get(defaultMsgType)
}
diff --git a/device/awg/magic_header.go b/device/awg/magic_header.go
new file mode 100644
index 0000000..f7ad45c
--- /dev/null
+++ b/device/awg/magic_header.go
@@ -0,0 +1,91 @@
+package awg
+
+import (
+ "cmp"
+ "fmt"
+ "slices"
+ "strconv"
+ "strings"
+)
+
+type MagicHeader struct {
+ Min uint32
+ Max uint32
+}
+
+func NewMagicHeaderSameValue(value uint32) MagicHeader {
+ return MagicHeader{Min: value, Max: value}
+}
+
+func NewMagicHeader(min, max uint32) (MagicHeader, error) {
+ if min > max {
+ return MagicHeader{}, fmt.Errorf("min (%d) cannot be greater than max (%d)", min, max)
+ }
+
+ return MagicHeader{Min: min, Max: max}, nil
+}
+
+func ParseMagicHeader(key, value string) (MagicHeader, error) {
+ splitLimits := strings.Split(value, "-")
+ if len(splitLimits) != 2 {
+ // if there is no hyphen, we treat it as single magic header value
+ magicHeader, err := strconv.ParseUint(value, 10, 32)
+ if err != nil {
+ return MagicHeader{}, fmt.Errorf("parse key: %s; value: %s; %w", key, value, err)
+ }
+
+ return NewMagicHeader(uint32(magicHeader), uint32(magicHeader))
+ }
+
+ min, err := strconv.ParseUint(splitLimits[0], 10, 32)
+ if err != nil {
+ return MagicHeader{}, fmt.Errorf("parse min key: %s; value: %s; %w", key, splitLimits[0], err)
+ }
+
+ max, err := strconv.ParseUint(splitLimits[1], 10, 32)
+ if err != nil {
+ return MagicHeader{}, fmt.Errorf("parse max key: %s; value: %s; %w", key, splitLimits[1], err)
+ }
+
+ magicHeader, err := NewMagicHeader(uint32(min), uint32(max))
+ if err != nil {
+ return MagicHeader{}, fmt.Errorf("new magicHeader key: %s; value: %s-%s; %w", key, splitLimits[0], splitLimits[1], err)
+ }
+
+ return magicHeader, nil
+}
+
+type MagicHeaders struct {
+ headers []MagicHeader
+ randomGenerator PRNG[uint32]
+}
+
+func NewMagicHeaders(magicHeaders []MagicHeader) (MagicHeaders, error) {
+ if len(magicHeaders) != 4 {
+ return MagicHeaders{}, fmt.Errorf("all header types should be included: %v", magicHeaders)
+ }
+
+ sortedMagicHeaders := slices.SortedFunc(slices.Values(magicHeaders), func(lhs MagicHeader, rhs MagicHeader) int {
+ return cmp.Compare(lhs.Min, rhs.Min)
+ })
+
+ for i := range 3 {
+ if sortedMagicHeaders[i].Min > sortedMagicHeaders[i+1].Min {
+ return MagicHeaders{}, fmt.Errorf(
+ "magic headers shouldn't overlap; %v > %v",
+ sortedMagicHeaders[i-1].Min,
+ sortedMagicHeaders[i].Min,
+ )
+ }
+ }
+
+ return MagicHeaders{headers: magicHeaders, randomGenerator: NewPRNG[uint32]()}, nil
+}
+
+func (mh *MagicHeaders) Get(defaultMsgType uint32) (uint32, error) {
+ if defaultMsgType == 0 || defaultMsgType > 4 {
+ return 0, fmt.Errorf("invalid msg type: %d", defaultMsgType)
+ }
+
+ return mh.randomGenerator.RandomSizeInRange(mh.headers[defaultMsgType-1].Min, mh.headers[defaultMsgType-1].Max), nil
+}
diff --git a/device/device.go b/device/device.go
index dbb975b..1f1c035 100644
--- a/device/device.go
+++ b/device/device.go
@@ -648,7 +648,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
isASecOn = true
}
- limits := make([]awg.Limit, 4)
+ limits := make([]awg.MagicHeader, 4)
if tempAwg.ASecCfg.InitPacketMagicHeader.Min > 4 {
isASecOn = true
@@ -659,7 +659,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
} else {
device.log.Verbosef("UAPI: Using default init type")
MessageInitiationType = DefaultMessageInitiationType
- limits[0] = awg.NewLimitSameValue(DefaultMessageInitiationType)
+ limits[0] = awg.NewMagicHeaderSameValue(DefaultMessageInitiationType)
}
if tempAwg.ASecCfg.ResponsePacketMagicHeader.Min > 4 {
@@ -671,7 +671,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
} else {
device.log.Verbosef("UAPI: Using default response type")
MessageResponseType = DefaultMessageResponseType
- limits[1] = awg.NewLimitSameValue(DefaultMessageResponseType)
+ limits[1] = awg.NewMagicHeaderSameValue(DefaultMessageResponseType)
}
if tempAwg.ASecCfg.UnderloadPacketMagicHeader.Min > 4 {
@@ -683,7 +683,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
} else {
device.log.Verbosef("UAPI: Using default underload type")
MessageCookieReplyType = DefaultMessageCookieReplyType
- limits[2] = awg.NewLimitSameValue(DefaultMessageCookieReplyType)
+ limits[2] = awg.NewMagicHeaderSameValue(DefaultMessageCookieReplyType)
}
if tempAwg.ASecCfg.TransportPacketMagicHeader.Min > 4 {
@@ -695,7 +695,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
} else {
device.log.Verbosef("UAPI: Using default transport type")
MessageTransportType = DefaultMessageTransportType
- limits[3] = awg.NewLimitSameValue(DefaultMessageTransportType)
+ limits[3] = awg.NewMagicHeaderSameValue(DefaultMessageTransportType)
}
isSameHeaderMap := map[uint32]struct{}{
@@ -705,7 +705,11 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
MessageTransportType: {},
}
- device.awg.MagicHeaders = awg.NewLimits(limits)
+ var err error
+ device.awg.MagicHeaders, err = awg.NewMagicHeaders(limits)
+ if err != nil {
+ errs = append(errs, ipcErrorf(ipc.IpcErrorInvalid, "new magic headers: %w", err))
+ }
// size will be different if same values
if len(isSameHeaderMap) != 4 {
@@ -859,8 +863,6 @@ func (device *Device) Logic(size int, packet *[]byte, bufsArrs *[MaxMessageSize]
}
junkSize := msgTypeToJunkSize[assumedMsgType]
- fmt.Println(msgTypeToJunkSize)
- fmt.Printf("Assumed message type: %d; size: %d", assumedMsgType, junkSize)
// transport size can align with other header types;
// making sure we have the right msgType
@@ -875,6 +877,7 @@ func (device *Device) Logic(size int, packet *[]byte, bufsArrs *[MaxMessageSize]
}
device.log.Verbosef("transport packet lined up with another msg type")
+
return device.handleTransport(size, packet, bufsArrs)
}
@@ -883,8 +886,9 @@ func (device *Device) getMsgType(packet *[]byte, junkSize int) (uint32, error) {
msgType, err := device.awg.GetLimitMin(msgTypeRange)
if err != nil {
- return 0, fmt.Errorf("aSec: get limit min for message type range: %d; %w", msgTypeRange, err)
+ return 0, fmt.Errorf("get limit min: %w", err)
}
+
return msgType, nil
}
diff --git a/device/noise-protocol.go b/device/noise-protocol.go
index aa7b0e6..bce2d2f 100644
--- a/device/noise-protocol.go
+++ b/device/noise-protocol.go
@@ -206,7 +206,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(handshake.remoteStatic[:])
device.awg.ASecMux.RLock()
- msgType, err := device.awg.Get(DefaultMessageInitiationType)
+ msgType, err := device.awg.GetMsgType(DefaultMessageInitiationType)
if err != nil {
device.awg.ASecMux.RUnlock()
return nil, fmt.Errorf("get message type: %w", err)
@@ -392,7 +392,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
var msg MessageResponse
device.awg.ASecMux.RLock()
- msg.Type, err = device.awg.Get(DefaultMessageResponseType)
+ msg.Type, err = device.awg.GetMsgType(DefaultMessageResponseType)
if err != nil {
device.awg.ASecMux.RUnlock()
return nil, fmt.Errorf("get message type: %w", err)
diff --git a/device/send.go b/device/send.go
index 4d07066..b87dda2 100644
--- a/device/send.go
+++ b/device/send.go
@@ -149,11 +149,6 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
peer.device.awg.JunkCreator.CreateJunkPackets(&junks)
peer.device.awg.ASecMux.RUnlock()
- if err != nil {
- peer.device.log.Errorf("%v - %v", peer, err)
- return err
- }
-
if len(junks) > 0 {
err = peer.SendBuffers(junks)
@@ -242,7 +237,7 @@ func (device *Device) SendHandshakeCookie(
device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
- msgType, err := device.awg.Get(DefaultMessageCookieReplyType)
+ msgType, err := device.awg.GetMsgType(DefaultMessageCookieReplyType)
if err != nil {
device.log.Errorf("Get message type for cookie reply: %v", err)
return err
@@ -535,7 +530,7 @@ func (device *Device) RoutineEncryption(id int) {
fieldReceiver := header[4:8]
fieldNonce := header[8:16]
- msgType, err := device.awg.Get(DefaultMessageTransportType)
+ msgType, err := device.awg.GetMsgType(DefaultMessageTransportType)
if err != nil {
device.log.Errorf("get message type for transport: %v", err)
continue
diff --git a/device/uapi.go b/device/uapi.go
index ce0152e..d99a4ee 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -371,39 +371,39 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
tempAwg.ASecCfg.IsSet = true
case "h1":
- magicHeader, err := awg.ParseMagicHeader(key, value)
+ initMagicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
}
- tempAwg.ASecCfg.InitPacketMagicHeader = magicHeader
+ tempAwg.ASecCfg.InitPacketMagicHeader = initMagicHeader
tempAwg.ASecCfg.IsSet = true
case "h2":
- magicHeader, err := awg.ParseMagicHeader(key, value)
+ responseMagicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
}
- tempAwg.ASecCfg.ResponsePacketMagicHeader = magicHeader
+ tempAwg.ASecCfg.ResponsePacketMagicHeader = responseMagicHeader
tempAwg.ASecCfg.IsSet = true
case "h3":
- magicHeader, err := awg.ParseMagicHeader(key, value)
+ cookieReplyMagicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
}
- tempAwg.ASecCfg.UnderloadPacketMagicHeader = magicHeader
+ tempAwg.ASecCfg.UnderloadPacketMagicHeader = cookieReplyMagicHeader
tempAwg.ASecCfg.IsSet = true
case "h4":
- magicHeader, err := awg.ParseMagicHeader(key, value)
+ transportMagicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
}
- tempAwg.ASecCfg.TransportPacketMagicHeader = magicHeader
+ tempAwg.ASecCfg.TransportPacketMagicHeader = transportMagicHeader
tempAwg.ASecCfg.IsSet = true
case "i1", "i2", "i3", "i4", "i5":