chore: some cleanup

This commit is contained in:
Mark Puha 2025-07-09 20:35:02 +02:00
parent c38c3ed54f
commit c50499d50e
6 changed files with 154 additions and 130 deletions

View file

@ -3,8 +3,6 @@ package awg
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"strconv"
"strings"
"sync" "sync"
"github.com/tevino/abool" "github.com/tevino/abool"
@ -19,145 +17,81 @@ type aSecCfgType struct {
ResponseHeaderJunkSize int ResponseHeaderJunkSize int
CookieReplyHeaderJunkSize int CookieReplyHeaderJunkSize int
TransportHeaderJunkSize int TransportHeaderJunkSize int
// InitPacketMagicHeader uint32
// ResponsePacketMagicHeader uint32
// UnderloadPacketMagicHeader uint32
// TransportPacketMagicHeader uint32
InitPacketMagicHeader Limit InitPacketMagicHeader MagicHeader
ResponsePacketMagicHeader Limit ResponsePacketMagicHeader MagicHeader
UnderloadPacketMagicHeader Limit UnderloadPacketMagicHeader MagicHeader
TransportPacketMagicHeader Limit TransportPacketMagicHeader MagicHeader
}
type Limit struct {
Min uint32
Max uint32
}
func NewLimitSameValue(value uint32) Limit {
return Limit{
Min: value,
Max: value,
}
}
func NewLimit(min, max uint32) (Limit, error) {
if min > max {
return Limit{}, fmt.Errorf("min (%d) cannot be greater than max (%d)", min, max)
}
return Limit{
Min: min,
Max: max,
}, nil
}
func ParseMagicHeader(key, value string) (Limit, error) {
splitLimits := strings.Split(value, "-")
if len(splitLimits) != 2 {
magicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return Limit{}, fmt.Errorf("parse key: %s; value: %s; %w", key, value, err)
}
return NewLimit(uint32(magicHeader), uint32(magicHeader))
}
min, err := strconv.ParseUint(splitLimits[0], 10, 32)
if err != nil {
return Limit{}, fmt.Errorf("parse min key: %s; value: %s; %w", key, splitLimits[0], err)
}
max, err := strconv.ParseUint(splitLimits[1], 10, 32)
if err != nil {
return Limit{}, fmt.Errorf("parse max key: %s; value: %s; %w", key, splitLimits[1], err)
}
limit, err := NewLimit(uint32(min), uint32(max))
if err != nil {
return Limit{}, fmt.Errorf("new limit key: %s; value: %s-%s; %w", key, splitLimits[0], splitLimits[1], err)
}
return limit, nil
}
type Limits struct {
Limits []Limit
randomGenerator PRNG[uint32]
}
func NewLimits(limits []Limit) Limits {
// TODO: check if limits doesn't overlap
return Limits{Limits: limits, randomGenerator: NewPRNG[uint32]()}
}
func (l *Limits) Get(defaultMsgType uint32) (uint32, error) {
if defaultMsgType == 0 || defaultMsgType > 4 {
return 0, fmt.Errorf("invalid message type: %d", defaultMsgType)
}
return l.randomGenerator.RandomSizeInRange(l.Limits[defaultMsgType-1].Min, l.Limits[defaultMsgType-1].Max), nil
} }
type Protocol struct { type Protocol struct {
IsASecOn abool.AtomicBool IsASecOn abool.AtomicBool
// TODO: revision the need of the mutex // TODO: revision the need of the mutex
ASecMux sync.RWMutex ASecMux sync.RWMutex
ASecCfg aSecCfgType ASecCfg aSecCfgType
JunkCreator junkCreator JunkCreator junkCreator
MagicHeaders MagicHeaders
HandshakeHandler SpecialHandshakeHandler HandshakeHandler SpecialHandshakeHandler
MagicHeaders Limits
} }
func (protocol *Protocol) CreateInitHeaderJunk() ([]byte, error) { func (protocol *Protocol) CreateInitHeaderJunk() ([]byte, error) {
protocol.ASecMux.RLock()
defer protocol.ASecMux.RUnlock()
return protocol.createHeaderJunk(protocol.ASecCfg.InitHeaderJunkSize, 0) return protocol.createHeaderJunk(protocol.ASecCfg.InitHeaderJunkSize, 0)
} }
func (protocol *Protocol) CreateResponseHeaderJunk() ([]byte, error) { func (protocol *Protocol) CreateResponseHeaderJunk() ([]byte, error) {
protocol.ASecMux.RLock()
defer protocol.ASecMux.RUnlock()
return protocol.createHeaderJunk(protocol.ASecCfg.ResponseHeaderJunkSize, 0) return protocol.createHeaderJunk(protocol.ASecCfg.ResponseHeaderJunkSize, 0)
} }
func (protocol *Protocol) CreateCookieReplyHeaderJunk() ([]byte, error) { func (protocol *Protocol) CreateCookieReplyHeaderJunk() ([]byte, error) {
protocol.ASecMux.RLock()
defer protocol.ASecMux.RUnlock()
return protocol.createHeaderJunk(protocol.ASecCfg.CookieReplyHeaderJunkSize, 0) return protocol.createHeaderJunk(protocol.ASecCfg.CookieReplyHeaderJunkSize, 0)
} }
func (protocol *Protocol) CreateTransportHeaderJunk(packetSize int) ([]byte, error) { func (protocol *Protocol) CreateTransportHeaderJunk(packetSize int) ([]byte, error) {
protocol.ASecMux.RLock()
defer protocol.ASecMux.RUnlock()
return protocol.createHeaderJunk(protocol.ASecCfg.TransportHeaderJunkSize, packetSize) return protocol.createHeaderJunk(protocol.ASecCfg.TransportHeaderJunkSize, packetSize)
} }
func (protocol *Protocol) createHeaderJunk(junkSize int, extraSize int) ([]byte, error) { func (protocol *Protocol) createHeaderJunk(junkSize int, extraSize int) ([]byte, error) {
var junk []byte if junkSize == 0 {
protocol.ASecMux.RLock() return nil, nil
if junkSize != 0 {
buf := make([]byte, 0, junkSize+extraSize)
writer := bytes.NewBuffer(buf[:0])
err := protocol.JunkCreator.AppendJunk(writer, junkSize)
if err != nil {
protocol.ASecMux.RUnlock()
return nil, err
}
junk = writer.Bytes()
} }
protocol.ASecMux.RUnlock()
var junk []byte
buf := make([]byte, 0, junkSize+extraSize)
writer := bytes.NewBuffer(buf[:0])
err := protocol.JunkCreator.AppendJunk(writer, junkSize)
if err != nil {
return nil, fmt.Errorf("append junk: %w", err)
}
junk = writer.Bytes()
return junk, nil return junk, nil
} }
func (protocol *Protocol) GetLimitMin(msgType uint32) (uint32, error) { func (protocol *Protocol) GetLimitMin(msgTypeRange uint32) (uint32, error) {
fmt.Println(protocol.MagicHeaders.Limits) for _, limit := range protocol.MagicHeaders.headers {
for _, limit := range protocol.MagicHeaders.Limits { if limit.Min <= msgTypeRange && msgTypeRange <= limit.Max {
if limit.Min <= msgType && msgType <= limit.Max {
return limit.Min, nil return limit.Min, nil
} }
} }
return 0, fmt.Errorf("no limit found for message type: %d", msgType) return 0, fmt.Errorf("no limit for range: %d", msgTypeRange)
} }
func (protocol *Protocol) Get(defaultMsgType uint32) (uint32, error) { func (protocol *Protocol) GetMsgType(defaultMsgType uint32) (uint32, error) {
return protocol.MagicHeaders.Get(defaultMsgType) return protocol.MagicHeaders.Get(defaultMsgType)
} }

View file

@ -0,0 +1,91 @@
package awg
import (
"cmp"
"fmt"
"slices"
"strconv"
"strings"
)
type MagicHeader struct {
Min uint32
Max uint32
}
func NewMagicHeaderSameValue(value uint32) MagicHeader {
return MagicHeader{Min: value, Max: value}
}
func NewMagicHeader(min, max uint32) (MagicHeader, error) {
if min > max {
return MagicHeader{}, fmt.Errorf("min (%d) cannot be greater than max (%d)", min, max)
}
return MagicHeader{Min: min, Max: max}, nil
}
func ParseMagicHeader(key, value string) (MagicHeader, error) {
splitLimits := strings.Split(value, "-")
if len(splitLimits) != 2 {
// if there is no hyphen, we treat it as single magic header value
magicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return MagicHeader{}, fmt.Errorf("parse key: %s; value: %s; %w", key, value, err)
}
return NewMagicHeader(uint32(magicHeader), uint32(magicHeader))
}
min, err := strconv.ParseUint(splitLimits[0], 10, 32)
if err != nil {
return MagicHeader{}, fmt.Errorf("parse min key: %s; value: %s; %w", key, splitLimits[0], err)
}
max, err := strconv.ParseUint(splitLimits[1], 10, 32)
if err != nil {
return MagicHeader{}, fmt.Errorf("parse max key: %s; value: %s; %w", key, splitLimits[1], err)
}
magicHeader, err := NewMagicHeader(uint32(min), uint32(max))
if err != nil {
return MagicHeader{}, fmt.Errorf("new magicHeader key: %s; value: %s-%s; %w", key, splitLimits[0], splitLimits[1], err)
}
return magicHeader, nil
}
type MagicHeaders struct {
headers []MagicHeader
randomGenerator PRNG[uint32]
}
func NewMagicHeaders(magicHeaders []MagicHeader) (MagicHeaders, error) {
if len(magicHeaders) != 4 {
return MagicHeaders{}, fmt.Errorf("all header types should be included: %v", magicHeaders)
}
sortedMagicHeaders := slices.SortedFunc(slices.Values(magicHeaders), func(lhs MagicHeader, rhs MagicHeader) int {
return cmp.Compare(lhs.Min, rhs.Min)
})
for i := range 3 {
if sortedMagicHeaders[i].Min > sortedMagicHeaders[i+1].Min {
return MagicHeaders{}, fmt.Errorf(
"magic headers shouldn't overlap; %v > %v",
sortedMagicHeaders[i-1].Min,
sortedMagicHeaders[i].Min,
)
}
}
return MagicHeaders{headers: magicHeaders, randomGenerator: NewPRNG[uint32]()}, nil
}
func (mh *MagicHeaders) Get(defaultMsgType uint32) (uint32, error) {
if defaultMsgType == 0 || defaultMsgType > 4 {
return 0, fmt.Errorf("invalid msg type: %d", defaultMsgType)
}
return mh.randomGenerator.RandomSizeInRange(mh.headers[defaultMsgType-1].Min, mh.headers[defaultMsgType-1].Max), nil
}

View file

@ -648,7 +648,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
isASecOn = true isASecOn = true
} }
limits := make([]awg.Limit, 4) limits := make([]awg.MagicHeader, 4)
if tempAwg.ASecCfg.InitPacketMagicHeader.Min > 4 { if tempAwg.ASecCfg.InitPacketMagicHeader.Min > 4 {
isASecOn = true isASecOn = true
@ -659,7 +659,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
} else { } else {
device.log.Verbosef("UAPI: Using default init type") device.log.Verbosef("UAPI: Using default init type")
MessageInitiationType = DefaultMessageInitiationType MessageInitiationType = DefaultMessageInitiationType
limits[0] = awg.NewLimitSameValue(DefaultMessageInitiationType) limits[0] = awg.NewMagicHeaderSameValue(DefaultMessageInitiationType)
} }
if tempAwg.ASecCfg.ResponsePacketMagicHeader.Min > 4 { if tempAwg.ASecCfg.ResponsePacketMagicHeader.Min > 4 {
@ -671,7 +671,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
} else { } else {
device.log.Verbosef("UAPI: Using default response type") device.log.Verbosef("UAPI: Using default response type")
MessageResponseType = DefaultMessageResponseType MessageResponseType = DefaultMessageResponseType
limits[1] = awg.NewLimitSameValue(DefaultMessageResponseType) limits[1] = awg.NewMagicHeaderSameValue(DefaultMessageResponseType)
} }
if tempAwg.ASecCfg.UnderloadPacketMagicHeader.Min > 4 { if tempAwg.ASecCfg.UnderloadPacketMagicHeader.Min > 4 {
@ -683,7 +683,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
} else { } else {
device.log.Verbosef("UAPI: Using default underload type") device.log.Verbosef("UAPI: Using default underload type")
MessageCookieReplyType = DefaultMessageCookieReplyType MessageCookieReplyType = DefaultMessageCookieReplyType
limits[2] = awg.NewLimitSameValue(DefaultMessageCookieReplyType) limits[2] = awg.NewMagicHeaderSameValue(DefaultMessageCookieReplyType)
} }
if tempAwg.ASecCfg.TransportPacketMagicHeader.Min > 4 { if tempAwg.ASecCfg.TransportPacketMagicHeader.Min > 4 {
@ -695,7 +695,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
} else { } else {
device.log.Verbosef("UAPI: Using default transport type") device.log.Verbosef("UAPI: Using default transport type")
MessageTransportType = DefaultMessageTransportType MessageTransportType = DefaultMessageTransportType
limits[3] = awg.NewLimitSameValue(DefaultMessageTransportType) limits[3] = awg.NewMagicHeaderSameValue(DefaultMessageTransportType)
} }
isSameHeaderMap := map[uint32]struct{}{ isSameHeaderMap := map[uint32]struct{}{
@ -705,7 +705,11 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
MessageTransportType: {}, MessageTransportType: {},
} }
device.awg.MagicHeaders = awg.NewLimits(limits) var err error
device.awg.MagicHeaders, err = awg.NewMagicHeaders(limits)
if err != nil {
errs = append(errs, ipcErrorf(ipc.IpcErrorInvalid, "new magic headers: %w", err))
}
// size will be different if same values // size will be different if same values
if len(isSameHeaderMap) != 4 { if len(isSameHeaderMap) != 4 {
@ -859,8 +863,6 @@ func (device *Device) Logic(size int, packet *[]byte, bufsArrs *[MaxMessageSize]
} }
junkSize := msgTypeToJunkSize[assumedMsgType] junkSize := msgTypeToJunkSize[assumedMsgType]
fmt.Println(msgTypeToJunkSize)
fmt.Printf("Assumed message type: %d; size: %d", assumedMsgType, junkSize)
// transport size can align with other header types; // transport size can align with other header types;
// making sure we have the right msgType // making sure we have the right msgType
@ -875,6 +877,7 @@ func (device *Device) Logic(size int, packet *[]byte, bufsArrs *[MaxMessageSize]
} }
device.log.Verbosef("transport packet lined up with another msg type") device.log.Verbosef("transport packet lined up with another msg type")
return device.handleTransport(size, packet, bufsArrs) return device.handleTransport(size, packet, bufsArrs)
} }
@ -883,8 +886,9 @@ func (device *Device) getMsgType(packet *[]byte, junkSize int) (uint32, error) {
msgType, err := device.awg.GetLimitMin(msgTypeRange) msgType, err := device.awg.GetLimitMin(msgTypeRange)
if err != nil { if err != nil {
return 0, fmt.Errorf("aSec: get limit min for message type range: %d; %w", msgTypeRange, err) return 0, fmt.Errorf("get limit min: %w", err)
} }
return msgType, nil return msgType, nil
} }

View file

@ -206,7 +206,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(handshake.remoteStatic[:]) handshake.mixHash(handshake.remoteStatic[:])
device.awg.ASecMux.RLock() device.awg.ASecMux.RLock()
msgType, err := device.awg.Get(DefaultMessageInitiationType) msgType, err := device.awg.GetMsgType(DefaultMessageInitiationType)
if err != nil { if err != nil {
device.awg.ASecMux.RUnlock() device.awg.ASecMux.RUnlock()
return nil, fmt.Errorf("get message type: %w", err) return nil, fmt.Errorf("get message type: %w", err)
@ -392,7 +392,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
var msg MessageResponse var msg MessageResponse
device.awg.ASecMux.RLock() device.awg.ASecMux.RLock()
msg.Type, err = device.awg.Get(DefaultMessageResponseType) msg.Type, err = device.awg.GetMsgType(DefaultMessageResponseType)
if err != nil { if err != nil {
device.awg.ASecMux.RUnlock() device.awg.ASecMux.RUnlock()
return nil, fmt.Errorf("get message type: %w", err) return nil, fmt.Errorf("get message type: %w", err)

View file

@ -149,11 +149,6 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
peer.device.awg.JunkCreator.CreateJunkPackets(&junks) peer.device.awg.JunkCreator.CreateJunkPackets(&junks)
peer.device.awg.ASecMux.RUnlock() peer.device.awg.ASecMux.RUnlock()
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
return err
}
if len(junks) > 0 { if len(junks) > 0 {
err = peer.SendBuffers(junks) err = peer.SendBuffers(junks)
@ -242,7 +237,7 @@ func (device *Device) SendHandshakeCookie(
device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString()) device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8]) sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
msgType, err := device.awg.Get(DefaultMessageCookieReplyType) msgType, err := device.awg.GetMsgType(DefaultMessageCookieReplyType)
if err != nil { if err != nil {
device.log.Errorf("Get message type for cookie reply: %v", err) device.log.Errorf("Get message type for cookie reply: %v", err)
return err return err
@ -535,7 +530,7 @@ func (device *Device) RoutineEncryption(id int) {
fieldReceiver := header[4:8] fieldReceiver := header[4:8]
fieldNonce := header[8:16] fieldNonce := header[8:16]
msgType, err := device.awg.Get(DefaultMessageTransportType) msgType, err := device.awg.GetMsgType(DefaultMessageTransportType)
if err != nil { if err != nil {
device.log.Errorf("get message type for transport: %v", err) device.log.Errorf("get message type for transport: %v", err)
continue continue

View file

@ -371,39 +371,39 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
tempAwg.ASecCfg.IsSet = true tempAwg.ASecCfg.IsSet = true
case "h1": case "h1":
magicHeader, err := awg.ParseMagicHeader(key, value) initMagicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
} }
tempAwg.ASecCfg.InitPacketMagicHeader = magicHeader tempAwg.ASecCfg.InitPacketMagicHeader = initMagicHeader
tempAwg.ASecCfg.IsSet = true tempAwg.ASecCfg.IsSet = true
case "h2": case "h2":
magicHeader, err := awg.ParseMagicHeader(key, value) responseMagicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
} }
tempAwg.ASecCfg.ResponsePacketMagicHeader = magicHeader tempAwg.ASecCfg.ResponsePacketMagicHeader = responseMagicHeader
tempAwg.ASecCfg.IsSet = true tempAwg.ASecCfg.IsSet = true
case "h3": case "h3":
magicHeader, err := awg.ParseMagicHeader(key, value) cookieReplyMagicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
} }
tempAwg.ASecCfg.UnderloadPacketMagicHeader = magicHeader tempAwg.ASecCfg.UnderloadPacketMagicHeader = cookieReplyMagicHeader
tempAwg.ASecCfg.IsSet = true tempAwg.ASecCfg.IsSet = true
case "h4": case "h4":
magicHeader, err := awg.ParseMagicHeader(key, value) transportMagicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
} }
tempAwg.ASecCfg.TransportPacketMagicHeader = magicHeader tempAwg.ASecCfg.TransportPacketMagicHeader = transportMagicHeader
tempAwg.ASecCfg.IsSet = true tempAwg.ASecCfg.IsSet = true
case "i1", "i2", "i3", "i4", "i5": case "i1", "i2", "i3", "i4", "i5":