chore: some cleanup

This commit is contained in:
Mark Puha 2025-07-09 20:35:02 +02:00
parent c38c3ed54f
commit c50499d50e
6 changed files with 154 additions and 130 deletions

View file

@ -3,8 +3,6 @@ package awg
import (
"bytes"
"fmt"
"strconv"
"strings"
"sync"
"github.com/tevino/abool"
@ -19,145 +17,81 @@ type aSecCfgType struct {
ResponseHeaderJunkSize int
CookieReplyHeaderJunkSize int
TransportHeaderJunkSize int
// InitPacketMagicHeader uint32
// ResponsePacketMagicHeader uint32
// UnderloadPacketMagicHeader uint32
// TransportPacketMagicHeader uint32
InitPacketMagicHeader Limit
ResponsePacketMagicHeader Limit
UnderloadPacketMagicHeader Limit
TransportPacketMagicHeader Limit
}
type Limit struct {
Min uint32
Max uint32
}
func NewLimitSameValue(value uint32) Limit {
return Limit{
Min: value,
Max: value,
}
}
func NewLimit(min, max uint32) (Limit, error) {
if min > max {
return Limit{}, fmt.Errorf("min (%d) cannot be greater than max (%d)", min, max)
}
return Limit{
Min: min,
Max: max,
}, nil
}
func ParseMagicHeader(key, value string) (Limit, error) {
splitLimits := strings.Split(value, "-")
if len(splitLimits) != 2 {
magicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return Limit{}, fmt.Errorf("parse key: %s; value: %s; %w", key, value, err)
}
return NewLimit(uint32(magicHeader), uint32(magicHeader))
}
min, err := strconv.ParseUint(splitLimits[0], 10, 32)
if err != nil {
return Limit{}, fmt.Errorf("parse min key: %s; value: %s; %w", key, splitLimits[0], err)
}
max, err := strconv.ParseUint(splitLimits[1], 10, 32)
if err != nil {
return Limit{}, fmt.Errorf("parse max key: %s; value: %s; %w", key, splitLimits[1], err)
}
limit, err := NewLimit(uint32(min), uint32(max))
if err != nil {
return Limit{}, fmt.Errorf("new limit key: %s; value: %s-%s; %w", key, splitLimits[0], splitLimits[1], err)
}
return limit, nil
}
type Limits struct {
Limits []Limit
randomGenerator PRNG[uint32]
}
func NewLimits(limits []Limit) Limits {
// TODO: check if limits doesn't overlap
return Limits{Limits: limits, randomGenerator: NewPRNG[uint32]()}
}
func (l *Limits) Get(defaultMsgType uint32) (uint32, error) {
if defaultMsgType == 0 || defaultMsgType > 4 {
return 0, fmt.Errorf("invalid message type: %d", defaultMsgType)
}
return l.randomGenerator.RandomSizeInRange(l.Limits[defaultMsgType-1].Min, l.Limits[defaultMsgType-1].Max), nil
InitPacketMagicHeader MagicHeader
ResponsePacketMagicHeader MagicHeader
UnderloadPacketMagicHeader MagicHeader
TransportPacketMagicHeader MagicHeader
}
type Protocol struct {
IsASecOn abool.AtomicBool
// TODO: revision the need of the mutex
ASecMux sync.RWMutex
ASecCfg aSecCfgType
JunkCreator junkCreator
ASecMux sync.RWMutex
ASecCfg aSecCfgType
JunkCreator junkCreator
MagicHeaders MagicHeaders
HandshakeHandler SpecialHandshakeHandler
MagicHeaders Limits
}
func (protocol *Protocol) CreateInitHeaderJunk() ([]byte, error) {
protocol.ASecMux.RLock()
defer protocol.ASecMux.RUnlock()
return protocol.createHeaderJunk(protocol.ASecCfg.InitHeaderJunkSize, 0)
}
func (protocol *Protocol) CreateResponseHeaderJunk() ([]byte, error) {
protocol.ASecMux.RLock()
defer protocol.ASecMux.RUnlock()
return protocol.createHeaderJunk(protocol.ASecCfg.ResponseHeaderJunkSize, 0)
}
func (protocol *Protocol) CreateCookieReplyHeaderJunk() ([]byte, error) {
protocol.ASecMux.RLock()
defer protocol.ASecMux.RUnlock()
return protocol.createHeaderJunk(protocol.ASecCfg.CookieReplyHeaderJunkSize, 0)
}
func (protocol *Protocol) CreateTransportHeaderJunk(packetSize int) ([]byte, error) {
protocol.ASecMux.RLock()
defer protocol.ASecMux.RUnlock()
return protocol.createHeaderJunk(protocol.ASecCfg.TransportHeaderJunkSize, packetSize)
}
func (protocol *Protocol) createHeaderJunk(junkSize int, extraSize int) ([]byte, error) {
var junk []byte
protocol.ASecMux.RLock()
if junkSize != 0 {
buf := make([]byte, 0, junkSize+extraSize)
writer := bytes.NewBuffer(buf[:0])
err := protocol.JunkCreator.AppendJunk(writer, junkSize)
if err != nil {
protocol.ASecMux.RUnlock()
return nil, err
}
junk = writer.Bytes()
if junkSize == 0 {
return nil, nil
}
protocol.ASecMux.RUnlock()
var junk []byte
buf := make([]byte, 0, junkSize+extraSize)
writer := bytes.NewBuffer(buf[:0])
err := protocol.JunkCreator.AppendJunk(writer, junkSize)
if err != nil {
return nil, fmt.Errorf("append junk: %w", err)
}
junk = writer.Bytes()
return junk, nil
}
func (protocol *Protocol) GetLimitMin(msgType uint32) (uint32, error) {
fmt.Println(protocol.MagicHeaders.Limits)
for _, limit := range protocol.MagicHeaders.Limits {
if limit.Min <= msgType && msgType <= limit.Max {
func (protocol *Protocol) GetLimitMin(msgTypeRange uint32) (uint32, error) {
for _, limit := range protocol.MagicHeaders.headers {
if limit.Min <= msgTypeRange && msgTypeRange <= limit.Max {
return limit.Min, nil
}
}
return 0, fmt.Errorf("no limit found for message type: %d", msgType)
return 0, fmt.Errorf("no limit for range: %d", msgTypeRange)
}
func (protocol *Protocol) Get(defaultMsgType uint32) (uint32, error) {
func (protocol *Protocol) GetMsgType(defaultMsgType uint32) (uint32, error) {
return protocol.MagicHeaders.Get(defaultMsgType)
}

View file

@ -0,0 +1,91 @@
package awg
import (
"cmp"
"fmt"
"slices"
"strconv"
"strings"
)
type MagicHeader struct {
Min uint32
Max uint32
}
func NewMagicHeaderSameValue(value uint32) MagicHeader {
return MagicHeader{Min: value, Max: value}
}
func NewMagicHeader(min, max uint32) (MagicHeader, error) {
if min > max {
return MagicHeader{}, fmt.Errorf("min (%d) cannot be greater than max (%d)", min, max)
}
return MagicHeader{Min: min, Max: max}, nil
}
func ParseMagicHeader(key, value string) (MagicHeader, error) {
splitLimits := strings.Split(value, "-")
if len(splitLimits) != 2 {
// if there is no hyphen, we treat it as single magic header value
magicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return MagicHeader{}, fmt.Errorf("parse key: %s; value: %s; %w", key, value, err)
}
return NewMagicHeader(uint32(magicHeader), uint32(magicHeader))
}
min, err := strconv.ParseUint(splitLimits[0], 10, 32)
if err != nil {
return MagicHeader{}, fmt.Errorf("parse min key: %s; value: %s; %w", key, splitLimits[0], err)
}
max, err := strconv.ParseUint(splitLimits[1], 10, 32)
if err != nil {
return MagicHeader{}, fmt.Errorf("parse max key: %s; value: %s; %w", key, splitLimits[1], err)
}
magicHeader, err := NewMagicHeader(uint32(min), uint32(max))
if err != nil {
return MagicHeader{}, fmt.Errorf("new magicHeader key: %s; value: %s-%s; %w", key, splitLimits[0], splitLimits[1], err)
}
return magicHeader, nil
}
type MagicHeaders struct {
headers []MagicHeader
randomGenerator PRNG[uint32]
}
func NewMagicHeaders(magicHeaders []MagicHeader) (MagicHeaders, error) {
if len(magicHeaders) != 4 {
return MagicHeaders{}, fmt.Errorf("all header types should be included: %v", magicHeaders)
}
sortedMagicHeaders := slices.SortedFunc(slices.Values(magicHeaders), func(lhs MagicHeader, rhs MagicHeader) int {
return cmp.Compare(lhs.Min, rhs.Min)
})
for i := range 3 {
if sortedMagicHeaders[i].Min > sortedMagicHeaders[i+1].Min {
return MagicHeaders{}, fmt.Errorf(
"magic headers shouldn't overlap; %v > %v",
sortedMagicHeaders[i-1].Min,
sortedMagicHeaders[i].Min,
)
}
}
return MagicHeaders{headers: magicHeaders, randomGenerator: NewPRNG[uint32]()}, nil
}
func (mh *MagicHeaders) Get(defaultMsgType uint32) (uint32, error) {
if defaultMsgType == 0 || defaultMsgType > 4 {
return 0, fmt.Errorf("invalid msg type: %d", defaultMsgType)
}
return mh.randomGenerator.RandomSizeInRange(mh.headers[defaultMsgType-1].Min, mh.headers[defaultMsgType-1].Max), nil
}

View file

@ -648,7 +648,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
isASecOn = true
}
limits := make([]awg.Limit, 4)
limits := make([]awg.MagicHeader, 4)
if tempAwg.ASecCfg.InitPacketMagicHeader.Min > 4 {
isASecOn = true
@ -659,7 +659,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
} else {
device.log.Verbosef("UAPI: Using default init type")
MessageInitiationType = DefaultMessageInitiationType
limits[0] = awg.NewLimitSameValue(DefaultMessageInitiationType)
limits[0] = awg.NewMagicHeaderSameValue(DefaultMessageInitiationType)
}
if tempAwg.ASecCfg.ResponsePacketMagicHeader.Min > 4 {
@ -671,7 +671,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
} else {
device.log.Verbosef("UAPI: Using default response type")
MessageResponseType = DefaultMessageResponseType
limits[1] = awg.NewLimitSameValue(DefaultMessageResponseType)
limits[1] = awg.NewMagicHeaderSameValue(DefaultMessageResponseType)
}
if tempAwg.ASecCfg.UnderloadPacketMagicHeader.Min > 4 {
@ -683,7 +683,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
} else {
device.log.Verbosef("UAPI: Using default underload type")
MessageCookieReplyType = DefaultMessageCookieReplyType
limits[2] = awg.NewLimitSameValue(DefaultMessageCookieReplyType)
limits[2] = awg.NewMagicHeaderSameValue(DefaultMessageCookieReplyType)
}
if tempAwg.ASecCfg.TransportPacketMagicHeader.Min > 4 {
@ -695,7 +695,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
} else {
device.log.Verbosef("UAPI: Using default transport type")
MessageTransportType = DefaultMessageTransportType
limits[3] = awg.NewLimitSameValue(DefaultMessageTransportType)
limits[3] = awg.NewMagicHeaderSameValue(DefaultMessageTransportType)
}
isSameHeaderMap := map[uint32]struct{}{
@ -705,7 +705,11 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
MessageTransportType: {},
}
device.awg.MagicHeaders = awg.NewLimits(limits)
var err error
device.awg.MagicHeaders, err = awg.NewMagicHeaders(limits)
if err != nil {
errs = append(errs, ipcErrorf(ipc.IpcErrorInvalid, "new magic headers: %w", err))
}
// size will be different if same values
if len(isSameHeaderMap) != 4 {
@ -859,8 +863,6 @@ func (device *Device) Logic(size int, packet *[]byte, bufsArrs *[MaxMessageSize]
}
junkSize := msgTypeToJunkSize[assumedMsgType]
fmt.Println(msgTypeToJunkSize)
fmt.Printf("Assumed message type: %d; size: %d", assumedMsgType, junkSize)
// transport size can align with other header types;
// making sure we have the right msgType
@ -875,6 +877,7 @@ func (device *Device) Logic(size int, packet *[]byte, bufsArrs *[MaxMessageSize]
}
device.log.Verbosef("transport packet lined up with another msg type")
return device.handleTransport(size, packet, bufsArrs)
}
@ -883,8 +886,9 @@ func (device *Device) getMsgType(packet *[]byte, junkSize int) (uint32, error) {
msgType, err := device.awg.GetLimitMin(msgTypeRange)
if err != nil {
return 0, fmt.Errorf("aSec: get limit min for message type range: %d; %w", msgTypeRange, err)
return 0, fmt.Errorf("get limit min: %w", err)
}
return msgType, nil
}

View file

@ -206,7 +206,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(handshake.remoteStatic[:])
device.awg.ASecMux.RLock()
msgType, err := device.awg.Get(DefaultMessageInitiationType)
msgType, err := device.awg.GetMsgType(DefaultMessageInitiationType)
if err != nil {
device.awg.ASecMux.RUnlock()
return nil, fmt.Errorf("get message type: %w", err)
@ -392,7 +392,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
var msg MessageResponse
device.awg.ASecMux.RLock()
msg.Type, err = device.awg.Get(DefaultMessageResponseType)
msg.Type, err = device.awg.GetMsgType(DefaultMessageResponseType)
if err != nil {
device.awg.ASecMux.RUnlock()
return nil, fmt.Errorf("get message type: %w", err)

View file

@ -149,11 +149,6 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
peer.device.awg.JunkCreator.CreateJunkPackets(&junks)
peer.device.awg.ASecMux.RUnlock()
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
return err
}
if len(junks) > 0 {
err = peer.SendBuffers(junks)
@ -242,7 +237,7 @@ func (device *Device) SendHandshakeCookie(
device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
msgType, err := device.awg.Get(DefaultMessageCookieReplyType)
msgType, err := device.awg.GetMsgType(DefaultMessageCookieReplyType)
if err != nil {
device.log.Errorf("Get message type for cookie reply: %v", err)
return err
@ -535,7 +530,7 @@ func (device *Device) RoutineEncryption(id int) {
fieldReceiver := header[4:8]
fieldNonce := header[8:16]
msgType, err := device.awg.Get(DefaultMessageTransportType)
msgType, err := device.awg.GetMsgType(DefaultMessageTransportType)
if err != nil {
device.log.Errorf("get message type for transport: %v", err)
continue

View file

@ -371,39 +371,39 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
tempAwg.ASecCfg.IsSet = true
case "h1":
magicHeader, err := awg.ParseMagicHeader(key, value)
initMagicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
}
tempAwg.ASecCfg.InitPacketMagicHeader = magicHeader
tempAwg.ASecCfg.InitPacketMagicHeader = initMagicHeader
tempAwg.ASecCfg.IsSet = true
case "h2":
magicHeader, err := awg.ParseMagicHeader(key, value)
responseMagicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
}
tempAwg.ASecCfg.ResponsePacketMagicHeader = magicHeader
tempAwg.ASecCfg.ResponsePacketMagicHeader = responseMagicHeader
tempAwg.ASecCfg.IsSet = true
case "h3":
magicHeader, err := awg.ParseMagicHeader(key, value)
cookieReplyMagicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
}
tempAwg.ASecCfg.UnderloadPacketMagicHeader = magicHeader
tempAwg.ASecCfg.UnderloadPacketMagicHeader = cookieReplyMagicHeader
tempAwg.ASecCfg.IsSet = true
case "h4":
magicHeader, err := awg.ParseMagicHeader(key, value)
transportMagicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
}
tempAwg.ASecCfg.TransportPacketMagicHeader = magicHeader
tempAwg.ASecCfg.TransportPacketMagicHeader = transportMagicHeader
tempAwg.ASecCfg.IsSet = true
case "i1", "i2", "i3", "i4", "i5":