chore: rename variables

This commit is contained in:
Mark Puha 2025-07-10 19:56:03 +02:00
parent be20e77077
commit 6992e18755
8 changed files with 206 additions and 197 deletions

View file

@ -8,7 +8,7 @@ import (
"github.com/tevino/abool"
)
type aSecCfgType struct {
type Cfg struct {
IsSet bool
JunkPacketCount int
JunkPacketMinSize int
@ -25,42 +25,42 @@ type aSecCfgType struct {
}
type Protocol struct {
IsASecOn abool.AtomicBool
IsOn abool.AtomicBool
// TODO: revision the need of the mutex
ASecMux sync.RWMutex
ASecCfg aSecCfgType
JunkCreator junkCreator
Mux sync.RWMutex
Cfg Cfg
JunkCreator JunkCreator
MagicHeaders MagicHeaders
HandshakeHandler SpecialHandshakeHandler
}
func (protocol *Protocol) CreateInitHeaderJunk() ([]byte, error) {
protocol.ASecMux.RLock()
defer protocol.ASecMux.RUnlock()
protocol.Mux.RLock()
defer protocol.Mux.RUnlock()
return protocol.createHeaderJunk(protocol.ASecCfg.InitHeaderJunkSize, 0)
return protocol.createHeaderJunk(protocol.Cfg.InitHeaderJunkSize, 0)
}
func (protocol *Protocol) CreateResponseHeaderJunk() ([]byte, error) {
protocol.ASecMux.RLock()
defer protocol.ASecMux.RUnlock()
protocol.Mux.RLock()
defer protocol.Mux.RUnlock()
return protocol.createHeaderJunk(protocol.ASecCfg.ResponseHeaderJunkSize, 0)
return protocol.createHeaderJunk(protocol.Cfg.ResponseHeaderJunkSize, 0)
}
func (protocol *Protocol) CreateCookieReplyHeaderJunk() ([]byte, error) {
protocol.ASecMux.RLock()
defer protocol.ASecMux.RUnlock()
protocol.Mux.RLock()
defer protocol.Mux.RUnlock()
return protocol.createHeaderJunk(protocol.ASecCfg.CookieReplyHeaderJunkSize, 0)
return protocol.createHeaderJunk(protocol.Cfg.CookieReplyHeaderJunkSize, 0)
}
func (protocol *Protocol) CreateTransportHeaderJunk(packetSize int) ([]byte, error) {
protocol.ASecMux.RLock()
defer protocol.ASecMux.RUnlock()
protocol.Mux.RLock()
defer protocol.Mux.RUnlock()
return protocol.createHeaderJunk(protocol.ASecCfg.TransportHeaderJunkSize, packetSize)
return protocol.createHeaderJunk(protocol.Cfg.TransportHeaderJunkSize, packetSize)
}
func (protocol *Protocol) createHeaderJunk(junkSize int, extraSize int) ([]byte, error) {
@ -68,7 +68,6 @@ func (protocol *Protocol) createHeaderJunk(junkSize int, extraSize int) ([]byte,
return nil, nil
}
var junk []byte
buf := make([]byte, 0, junkSize+extraSize)
writer := bytes.NewBuffer(buf[:0])
@ -77,19 +76,17 @@ func (protocol *Protocol) createHeaderJunk(junkSize int, extraSize int) ([]byte,
return nil, fmt.Errorf("append junk: %w", err)
}
junk = writer.Bytes()
return junk, nil
return writer.Bytes(), nil
}
func (protocol *Protocol) GetMagicHeaderMinFor(msgTypeRange uint32) (uint32, error) {
func (protocol *Protocol) GetMagicHeaderMinFor(msgType uint32) (uint32, error) {
for _, limit := range protocol.MagicHeaders.headerValues {
if limit.Min <= msgTypeRange && msgTypeRange <= limit.Max {
if limit.Min <= msgType && msgType <= limit.Max {
return limit.Min, nil
}
}
return 0, fmt.Errorf("no header for range: %d", msgTypeRange)
return 0, fmt.Errorf("no header for value: %d", msgType)
}
func (protocol *Protocol) GetMsgType(defaultMsgType uint32) (uint32, error) {

View file

@ -5,23 +5,23 @@ import (
"fmt"
)
type junkCreator struct {
aSecCfg aSecCfgType
type JunkCreator struct {
cfg Cfg
randomGenerator PRNG[int]
}
// TODO: refactor param to only pass the junk related params
func NewJunkCreator(aSecCfg aSecCfgType) junkCreator {
return junkCreator{aSecCfg: aSecCfg, randomGenerator: NewPRNG[int]()}
func NewJunkCreator(cfg Cfg) JunkCreator {
return JunkCreator{cfg: cfg, randomGenerator: NewPRNG[int]()}
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) CreateJunkPackets(junks *[][]byte) {
if jc.aSecCfg.JunkPacketCount == 0 {
// Should be called with awg mux RLocked
func (jc *JunkCreator) CreateJunkPackets(junks *[][]byte) {
if jc.cfg.JunkPacketCount == 0 {
return
}
for range jc.aSecCfg.JunkPacketCount {
for range jc.cfg.JunkPacketCount {
packetSize := jc.randomPacketSize()
junk := jc.randomJunkWithSize(packetSize)
*junks = append(*junks, junk)
@ -29,13 +29,13 @@ func (jc *junkCreator) CreateJunkPackets(junks *[][]byte) {
return
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) randomPacketSize() int {
return jc.randomGenerator.RandomSizeInRange(jc.aSecCfg.JunkPacketMinSize, jc.aSecCfg.JunkPacketMaxSize)
// Should be called with awg mux RLocked
func (jc *JunkCreator) randomPacketSize() int {
return jc.randomGenerator.RandomSizeInRange(jc.cfg.JunkPacketMinSize, jc.cfg.JunkPacketMaxSize)
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) AppendJunk(writer *bytes.Buffer, size int) error {
// Should be called with awg mux RLocked
func (jc *JunkCreator) AppendJunk(writer *bytes.Buffer, size int) error {
headerJunk := jc.randomJunkWithSize(size)
_, err := writer.Write(headerJunk)
if err != nil {
@ -44,7 +44,7 @@ func (jc *junkCreator) AppendJunk(writer *bytes.Buffer, size int) error {
return nil
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) randomJunkWithSize(size int) []byte {
// Should be called with awg mux RLocked
func (jc *JunkCreator) randomJunkWithSize(size int) []byte {
return jc.randomGenerator.ReadSize(size)
}

View file

@ -6,8 +6,8 @@ import (
"testing"
)
func setUpJunkCreator() junkCreator {
jc := NewJunkCreator(aSecCfgType{
func setUpJunkCreator() JunkCreator {
jc := NewJunkCreator(Cfg{
IsSet: true,
JunkPacketCount: 5,
JunkPacketMinSize: 500,
@ -27,7 +27,7 @@ func setUpJunkCreator() junkCreator {
func Test_junkCreator_createJunkPackets(t *testing.T) {
jc := setUpJunkCreator()
t.Run("valid", func(t *testing.T) {
got := make([][]byte, 0, jc.aSecCfg.JunkPacketCount)
got := make([][]byte, 0, jc.cfg.JunkPacketCount)
jc.CreateJunkPackets(&got)
seen := make(map[string]bool)
for _, junk := range got {
@ -62,13 +62,13 @@ func Test_junkCreator_randomPacketSize(t *testing.T) {
jc := setUpJunkCreator()
for range [30]struct{}{} {
t.Run("valid", func(t *testing.T) {
if got := jc.randomPacketSize(); jc.aSecCfg.JunkPacketMinSize > got ||
got > jc.aSecCfg.JunkPacketMaxSize {
if got := jc.randomPacketSize(); jc.cfg.JunkPacketMinSize > got ||
got > jc.cfg.JunkPacketMaxSize {
t.Errorf(
"junkCreator.randomPacketSize() = %v, not between range [%v,%v]",
got,
jc.aSecCfg.JunkPacketMinSize,
jc.aSecCfg.JunkPacketMaxSize,
jc.cfg.JunkPacketMinSize,
jc.cfg.JunkPacketMaxSize,
)
}
})

View file

@ -593,105 +593,105 @@ func (device *Device) resetProtocol() {
}
func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
if !tempAwg.ASecCfg.IsSet && !tempAwg.HandshakeHandler.IsSet {
if !tempAwg.Cfg.IsSet && !tempAwg.HandshakeHandler.IsSet {
return nil
}
var errs []error
isASecOn := false
device.awg.ASecMux.Lock()
if tempAwg.ASecCfg.JunkPacketCount < 0 {
device.awg.Mux.Lock()
if tempAwg.Cfg.JunkPacketCount < 0 {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketCount should be non negative",
),
)
}
device.awg.ASecCfg.JunkPacketCount = tempAwg.ASecCfg.JunkPacketCount
if tempAwg.ASecCfg.JunkPacketCount != 0 {
device.awg.Cfg.JunkPacketCount = tempAwg.Cfg.JunkPacketCount
if tempAwg.Cfg.JunkPacketCount != 0 {
isASecOn = true
}
device.awg.ASecCfg.JunkPacketMinSize = tempAwg.ASecCfg.JunkPacketMinSize
if tempAwg.ASecCfg.JunkPacketMinSize != 0 {
device.awg.Cfg.JunkPacketMinSize = tempAwg.Cfg.JunkPacketMinSize
if tempAwg.Cfg.JunkPacketMinSize != 0 {
isASecOn = true
}
if device.awg.ASecCfg.JunkPacketCount > 0 &&
tempAwg.ASecCfg.JunkPacketMaxSize == tempAwg.ASecCfg.JunkPacketMinSize {
if device.awg.Cfg.JunkPacketCount > 0 &&
tempAwg.Cfg.JunkPacketMaxSize == tempAwg.Cfg.JunkPacketMinSize {
tempAwg.ASecCfg.JunkPacketMaxSize++ // to make rand gen work
tempAwg.Cfg.JunkPacketMaxSize++ // to make rand gen work
}
if tempAwg.ASecCfg.JunkPacketMaxSize >= MaxSegmentSize {
device.awg.ASecCfg.JunkPacketMinSize = 0
device.awg.ASecCfg.JunkPacketMaxSize = 1
if tempAwg.Cfg.JunkPacketMaxSize >= MaxSegmentSize {
device.awg.Cfg.JunkPacketMinSize = 0
device.awg.Cfg.JunkPacketMaxSize = 1
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
tempAwg.ASecCfg.JunkPacketMaxSize,
tempAwg.Cfg.JunkPacketMaxSize,
MaxSegmentSize,
))
} else if tempAwg.ASecCfg.JunkPacketMaxSize < tempAwg.ASecCfg.JunkPacketMinSize {
} else if tempAwg.Cfg.JunkPacketMaxSize < tempAwg.Cfg.JunkPacketMinSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
"maxSize: %d; should be greater than minSize: %d",
tempAwg.ASecCfg.JunkPacketMaxSize,
tempAwg.ASecCfg.JunkPacketMinSize,
tempAwg.Cfg.JunkPacketMaxSize,
tempAwg.Cfg.JunkPacketMinSize,
))
} else {
device.awg.ASecCfg.JunkPacketMaxSize = tempAwg.ASecCfg.JunkPacketMaxSize
device.awg.Cfg.JunkPacketMaxSize = tempAwg.Cfg.JunkPacketMaxSize
}
if tempAwg.ASecCfg.JunkPacketMaxSize != 0 {
if tempAwg.Cfg.JunkPacketMaxSize != 0 {
isASecOn = true
}
limits := make([]awg.MagicHeader, 4)
if tempAwg.ASecCfg.InitPacketMagicHeader.Min > 4 {
if tempAwg.Cfg.InitPacketMagicHeader.Min > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating init_packet_magic_header")
device.awg.ASecCfg.InitPacketMagicHeader = tempAwg.ASecCfg.InitPacketMagicHeader
limits[0] = tempAwg.ASecCfg.InitPacketMagicHeader
MessageInitiationType = device.awg.ASecCfg.InitPacketMagicHeader.Min
device.awg.Cfg.InitPacketMagicHeader = tempAwg.Cfg.InitPacketMagicHeader
limits[0] = tempAwg.Cfg.InitPacketMagicHeader
MessageInitiationType = device.awg.Cfg.InitPacketMagicHeader.Min
} else {
device.log.Verbosef("UAPI: Using default init type")
MessageInitiationType = DefaultMessageInitiationType
limits[0] = awg.NewMagicHeaderSameValue(DefaultMessageInitiationType)
}
if tempAwg.ASecCfg.ResponsePacketMagicHeader.Min > 4 {
if tempAwg.Cfg.ResponsePacketMagicHeader.Min > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
device.awg.ASecCfg.ResponsePacketMagicHeader = tempAwg.ASecCfg.ResponsePacketMagicHeader
MessageResponseType = device.awg.ASecCfg.ResponsePacketMagicHeader.Min
limits[1] = tempAwg.ASecCfg.ResponsePacketMagicHeader
device.awg.Cfg.ResponsePacketMagicHeader = tempAwg.Cfg.ResponsePacketMagicHeader
MessageResponseType = device.awg.Cfg.ResponsePacketMagicHeader.Min
limits[1] = tempAwg.Cfg.ResponsePacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default response type")
MessageResponseType = DefaultMessageResponseType
limits[1] = awg.NewMagicHeaderSameValue(DefaultMessageResponseType)
}
if tempAwg.ASecCfg.UnderloadPacketMagicHeader.Min > 4 {
if tempAwg.Cfg.UnderloadPacketMagicHeader.Min > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
device.awg.ASecCfg.UnderloadPacketMagicHeader = tempAwg.ASecCfg.UnderloadPacketMagicHeader
MessageCookieReplyType = device.awg.ASecCfg.UnderloadPacketMagicHeader.Min
limits[2] = tempAwg.ASecCfg.UnderloadPacketMagicHeader
device.awg.Cfg.UnderloadPacketMagicHeader = tempAwg.Cfg.UnderloadPacketMagicHeader
MessageCookieReplyType = device.awg.Cfg.UnderloadPacketMagicHeader.Min
limits[2] = tempAwg.Cfg.UnderloadPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default underload type")
MessageCookieReplyType = DefaultMessageCookieReplyType
limits[2] = awg.NewMagicHeaderSameValue(DefaultMessageCookieReplyType)
}
if tempAwg.ASecCfg.TransportPacketMagicHeader.Min > 4 {
if tempAwg.Cfg.TransportPacketMagicHeader.Min > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
device.awg.ASecCfg.TransportPacketMagicHeader = tempAwg.ASecCfg.TransportPacketMagicHeader
MessageTransportType = device.awg.ASecCfg.TransportPacketMagicHeader.Min
limits[3] = tempAwg.ASecCfg.TransportPacketMagicHeader
device.awg.Cfg.TransportPacketMagicHeader = tempAwg.Cfg.TransportPacketMagicHeader
MessageTransportType = device.awg.Cfg.TransportPacketMagicHeader.Min
limits[3] = tempAwg.Cfg.TransportPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default transport type")
MessageTransportType = DefaultMessageTransportType
@ -724,75 +724,75 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
)
}
newInitSize := MessageInitiationSize + tempAwg.ASecCfg.InitHeaderJunkSize
newInitSize := MessageInitiationSize + tempAwg.Cfg.InitHeaderJunkSize
if newInitSize >= MaxSegmentSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempAwg.ASecCfg.InitHeaderJunkSize,
tempAwg.Cfg.InitHeaderJunkSize,
MaxSegmentSize,
),
)
} else {
device.awg.ASecCfg.InitHeaderJunkSize = tempAwg.ASecCfg.InitHeaderJunkSize
device.awg.Cfg.InitHeaderJunkSize = tempAwg.Cfg.InitHeaderJunkSize
}
if tempAwg.ASecCfg.InitHeaderJunkSize != 0 {
if tempAwg.Cfg.InitHeaderJunkSize != 0 {
isASecOn = true
}
newResponseSize := MessageResponseSize + tempAwg.ASecCfg.ResponseHeaderJunkSize
newResponseSize := MessageResponseSize + tempAwg.Cfg.ResponseHeaderJunkSize
if newResponseSize >= MaxSegmentSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempAwg.ASecCfg.ResponseHeaderJunkSize,
tempAwg.Cfg.ResponseHeaderJunkSize,
MaxSegmentSize,
),
)
} else {
device.awg.ASecCfg.ResponseHeaderJunkSize = tempAwg.ASecCfg.ResponseHeaderJunkSize
device.awg.Cfg.ResponseHeaderJunkSize = tempAwg.Cfg.ResponseHeaderJunkSize
}
if tempAwg.ASecCfg.ResponseHeaderJunkSize != 0 {
if tempAwg.Cfg.ResponseHeaderJunkSize != 0 {
isASecOn = true
}
newCookieSize := MessageCookieReplySize + tempAwg.ASecCfg.CookieReplyHeaderJunkSize
newCookieSize := MessageCookieReplySize + tempAwg.Cfg.CookieReplyHeaderJunkSize
if newCookieSize >= MaxSegmentSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`cookie reply size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempAwg.ASecCfg.CookieReplyHeaderJunkSize,
tempAwg.Cfg.CookieReplyHeaderJunkSize,
MaxSegmentSize,
),
)
} else {
device.awg.ASecCfg.CookieReplyHeaderJunkSize = tempAwg.ASecCfg.CookieReplyHeaderJunkSize
device.awg.Cfg.CookieReplyHeaderJunkSize = tempAwg.Cfg.CookieReplyHeaderJunkSize
}
if tempAwg.ASecCfg.CookieReplyHeaderJunkSize != 0 {
if tempAwg.Cfg.CookieReplyHeaderJunkSize != 0 {
isASecOn = true
}
newTransportSize := MessageTransportSize + tempAwg.ASecCfg.TransportHeaderJunkSize
newTransportSize := MessageTransportSize + tempAwg.Cfg.TransportHeaderJunkSize
if newTransportSize >= MaxSegmentSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`transport size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempAwg.ASecCfg.TransportHeaderJunkSize,
tempAwg.Cfg.TransportHeaderJunkSize,
MaxSegmentSize,
),
)
} else {
device.awg.ASecCfg.TransportHeaderJunkSize = tempAwg.ASecCfg.TransportHeaderJunkSize
device.awg.Cfg.TransportHeaderJunkSize = tempAwg.Cfg.TransportHeaderJunkSize
}
if tempAwg.ASecCfg.TransportHeaderJunkSize != 0 {
if tempAwg.Cfg.TransportHeaderJunkSize != 0 {
isASecOn = true
}
@ -815,10 +815,10 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
)
} else {
msgTypeToJunkSize = map[uint32]int{
MessageInitiationType: device.awg.ASecCfg.InitHeaderJunkSize,
MessageResponseType: device.awg.ASecCfg.ResponseHeaderJunkSize,
MessageCookieReplyType: device.awg.ASecCfg.CookieReplyHeaderJunkSize,
MessageTransportType: device.awg.ASecCfg.TransportHeaderJunkSize,
MessageInitiationType: device.awg.Cfg.InitHeaderJunkSize,
MessageResponseType: device.awg.Cfg.ResponseHeaderJunkSize,
MessageCookieReplyType: device.awg.Cfg.CookieReplyHeaderJunkSize,
MessageTransportType: device.awg.Cfg.TransportHeaderJunkSize,
}
packetSizeToMsgType = map[int]uint32{
@ -829,8 +829,8 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
}
}
device.awg.IsASecOn.SetTo(isASecOn)
device.awg.JunkCreator = awg.NewJunkCreator(device.awg.ASecCfg)
device.awg.IsOn.SetTo(isASecOn)
device.awg.JunkCreator = awg.NewJunkCreator(device.awg.Cfg)
if tempAwg.HandshakeHandler.IsSet {
if err := tempAwg.HandshakeHandler.Validate(); err != nil {
@ -838,78 +838,89 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
ipc.IpcErrorInvalid, "handshake handler validate: %w", err))
} else {
device.awg.HandshakeHandler = tempAwg.HandshakeHandler
device.awg.HandshakeHandler.ControlledJunk.DefaultJunkCount = tempAwg.ASecCfg.JunkPacketCount
device.awg.HandshakeHandler.SpecialJunk.DefaultJunkCount = tempAwg.ASecCfg.JunkPacketCount
device.awg.HandshakeHandler.ControlledJunk.DefaultJunkCount = tempAwg.Cfg.JunkPacketCount
device.awg.HandshakeHandler.SpecialJunk.DefaultJunkCount = tempAwg.Cfg.JunkPacketCount
device.version = VersionAwgSpecialHandshake
}
} else {
device.version = VersionAwg
}
device.awg.ASecMux.Unlock()
device.awg.Mux.Unlock()
return errors.Join(errs...)
}
func (device *Device) Logic(size int, packet *[]byte, bufsArrs *[MaxMessageSize]byte) (msgType uint32, err error) {
func (device *Device) ProcessAWGPacket(size int, packet *[]byte, buffer *[MaxMessageSize]byte) (uint32, error) {
// TODO:
// if awg.WaitResponse.ShouldWait.IsSet() {
// awg.WaitResponse.Channel <- struct{}{}
// }
assumedMsgType, hasAssumedType := packetSizeToMsgType[size]
if !hasAssumedType {
return device.handleTransport(size, packet, bufsArrs)
}
expectedMsgType, isKnownSize := packetSizeToMsgType[size]
if !isKnownSize {
msgType, err := device.handleTransport(size, packet, buffer)
junkSize := msgTypeToJunkSize[assumedMsgType]
if err != nil {
return 0, fmt.Errorf("handle transport: %w", err)
}
// transport size can align with other header types;
// making sure we have the right msgType
msgType, err = device.getMsgType(packet, junkSize)
if err != nil {
return 0, fmt.Errorf("aSec: get msg type: %w", err)
}
if msgType == assumedMsgType {
*packet = (*packet)[junkSize:]
return msgType, nil
}
device.log.Verbosef("transport packet lined up with another msg type")
return device.handleTransport(size, packet, bufsArrs)
}
func (device *Device) getMsgType(packet *[]byte, junkSize int) (uint32, error) {
msgTypeRange := binary.LittleEndian.Uint32((*packet)[junkSize : junkSize+4])
msgType, err := device.awg.GetMagicHeaderMinFor(msgTypeRange)
junkSize := msgTypeToJunkSize[expectedMsgType]
// transport size can align with other header types;
// making sure we have the right actualMsgType
actualMsgType, err := device.getMsgType(packet, junkSize)
if err != nil {
return 0, fmt.Errorf("get limit min: %w", err)
return 0, fmt.Errorf("get msg type: %w", err)
}
if actualMsgType == expectedMsgType {
*packet = (*packet)[junkSize:]
return actualMsgType, nil
}
device.log.Verbosef("awg: transport packet lined up with another msg type")
msgType, err := device.handleTransport(size, packet, buffer)
if err != nil {
return 0, fmt.Errorf("handle transport: %w", err)
}
return msgType, nil
}
func (device *Device) handleTransport(size int, packet *[]byte, bufsArrs *[MaxMessageSize]byte) (uint32, error) {
transportJunkSize := device.awg.ASecCfg.TransportHeaderJunkSize
func (device *Device) getMsgType(packet *[]byte, junkSize int) (uint32, error) {
msgTypeValue := binary.LittleEndian.Uint32((*packet)[junkSize : junkSize+4])
msgType, err := device.awg.GetMagicHeaderMinFor(msgTypeValue)
msgType, err := device.getMsgType(packet, device.awg.ASecCfg.TransportHeaderJunkSize)
if err != nil {
return 0, fmt.Errorf("get magic header min: %w", err)
}
return msgType, nil
}
func (device *Device) handleTransport(size int, packet *[]byte, buffer *[MaxMessageSize]byte) (uint32, error) {
junkSize := device.awg.Cfg.TransportHeaderJunkSize
msgType, err := device.getMsgType(packet, junkSize)
if err != nil {
return 0, fmt.Errorf("get msg type: %w", err)
}
if msgType != MessageTransportType {
// probably a junk packet
return 0, fmt.Errorf("aSec: Received message with unknown type: %d", msgType)
return 0, fmt.Errorf("Received message with unknown type: %d", msgType)
}
if transportJunkSize > 0 {
// remove junk from bufsArrs by shifting the packet
if junkSize > 0 {
// remove junk from buffer by shifting the packet
// this buffer is also used for decryption, so it needs to be corrected
copy((*bufsArrs)[:size], (*packet)[transportJunkSize:])
size -= transportJunkSize
copy((*buffer)[:size], (*packet)[junkSize:])
size -= junkSize
// need to reinitialize packet as well
(*packet) = (*packet)[:size]
}

View file

@ -205,10 +205,10 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(handshake.remoteStatic[:])
device.awg.ASecMux.RLock()
device.awg.Mux.RLock()
msgType, err := device.awg.GetMsgType(DefaultMessageInitiationType)
if err != nil {
device.awg.ASecMux.RUnlock()
device.awg.Mux.RUnlock()
return nil, fmt.Errorf("get message type: %w", err)
}
@ -216,7 +216,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
Type: msgType,
Ephemeral: handshake.localEphemeral.publicKey(),
}
device.awg.ASecMux.RUnlock()
device.awg.Mux.RUnlock()
handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:])
@ -270,13 +270,13 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
chainKey [blake2s.Size]byte
)
device.awg.ASecMux.RLock()
device.awg.Mux.RLock()
if msg.Type != MessageInitiationType {
device.awg.ASecMux.RUnlock()
device.awg.Mux.RUnlock()
return nil
}
device.awg.ASecMux.RUnlock()
device.awg.Mux.RUnlock()
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
@ -391,14 +391,14 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
}
var msg MessageResponse
device.awg.ASecMux.RLock()
device.awg.Mux.RLock()
msg.Type, err = device.awg.GetMsgType(DefaultMessageResponseType)
if err != nil {
device.awg.ASecMux.RUnlock()
device.awg.Mux.RUnlock()
return nil, fmt.Errorf("get message type: %w", err)
}
device.awg.ASecMux.RUnlock()
device.awg.Mux.RUnlock()
msg.Sender = handshake.localIndex
msg.Receiver = handshake.remoteIndex
@ -448,13 +448,13 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
}
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
device.awg.ASecMux.RLock()
device.awg.Mux.RLock()
if msg.Type != MessageResponseType {
device.awg.ASecMux.RUnlock()
device.awg.Mux.RUnlock()
return nil
}
device.awg.ASecMux.RUnlock()
device.awg.Mux.RUnlock()
// lookup handshake by receiver

View file

@ -129,7 +129,7 @@ func (device *Device) RoutineReceiveIncoming(
}
deathSpiral = 0
device.awg.ASecMux.RLock()
device.awg.Mux.RLock()
// handle each packet in the batch
for i, size := range sizes[:count] {
if size < MinMessageSize {
@ -140,9 +140,10 @@ func (device *Device) RoutineReceiveIncoming(
packet := bufsArrs[i][:size]
var msgType uint32
if device.isAWG() {
msgType, err = device.Logic(size, &packet, bufsArrs[i])
msgType, err = device.ProcessAWGPacket(size, &packet, bufsArrs[i])
if err != nil {
device.log.Verbosef("awg device logic: %v", err)
device.log.Verbosef("awg: process packet: %v", err)
continue
}
} else {
@ -232,7 +233,7 @@ func (device *Device) RoutineReceiveIncoming(
default:
}
}
device.awg.ASecMux.RUnlock()
device.awg.Mux.RUnlock()
for peer, elemsContainer := range elemsByPeer {
if peer.isRunning.Load() {
peer.queue.inbound.c <- elemsContainer
@ -291,7 +292,7 @@ func (device *Device) RoutineHandshake(id int) {
for elem := range device.queue.handshake.c {
device.awg.ASecMux.RLock()
device.awg.Mux.RLock()
// handle cookie fields and ratelimiting
@ -449,7 +450,7 @@ func (device *Device) RoutineHandshake(id int) {
peer.SendKeepalive()
}
skip:
device.awg.ASecMux.RUnlock()
device.awg.Mux.RUnlock()
device.PutMessageBuffer(elem.buffer)
}
}

View file

@ -130,7 +130,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
if peer.device.version >= VersionAwg {
var junks [][]byte
if peer.device.version == VersionAwgSpecialHandshake {
peer.device.awg.ASecMux.RLock()
peer.device.awg.Mux.RLock()
// set junks depending on packet type
junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk()
if junks == nil {
@ -141,13 +141,13 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
} else {
peer.device.log.Verbosef("%v - Special junks sent", peer)
}
peer.device.awg.ASecMux.RUnlock()
peer.device.awg.Mux.RUnlock()
} else {
junks = make([][]byte, 0, peer.device.awg.ASecCfg.JunkPacketCount)
junks = make([][]byte, 0, peer.device.awg.Cfg.JunkPacketCount)
}
peer.device.awg.ASecMux.RLock()
peer.device.awg.Mux.RLock()
peer.device.awg.JunkCreator.CreateJunkPackets(&junks)
peer.device.awg.ASecMux.RUnlock()
peer.device.awg.Mux.RUnlock()
if len(junks) > 0 {
err = peer.SendBuffers(junks)

View file

@ -99,26 +99,26 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
}
if device.isAWG() {
if device.awg.ASecCfg.JunkPacketCount != 0 {
sendf("jc=%d", device.awg.ASecCfg.JunkPacketCount)
if device.awg.Cfg.JunkPacketCount != 0 {
sendf("jc=%d", device.awg.Cfg.JunkPacketCount)
}
if device.awg.ASecCfg.JunkPacketMinSize != 0 {
sendf("jmin=%d", device.awg.ASecCfg.JunkPacketMinSize)
if device.awg.Cfg.JunkPacketMinSize != 0 {
sendf("jmin=%d", device.awg.Cfg.JunkPacketMinSize)
}
if device.awg.ASecCfg.JunkPacketMaxSize != 0 {
sendf("jmax=%d", device.awg.ASecCfg.JunkPacketMaxSize)
if device.awg.Cfg.JunkPacketMaxSize != 0 {
sendf("jmax=%d", device.awg.Cfg.JunkPacketMaxSize)
}
if device.awg.ASecCfg.InitHeaderJunkSize != 0 {
sendf("s1=%d", device.awg.ASecCfg.InitHeaderJunkSize)
if device.awg.Cfg.InitHeaderJunkSize != 0 {
sendf("s1=%d", device.awg.Cfg.InitHeaderJunkSize)
}
if device.awg.ASecCfg.ResponseHeaderJunkSize != 0 {
sendf("s2=%d", device.awg.ASecCfg.ResponseHeaderJunkSize)
if device.awg.Cfg.ResponseHeaderJunkSize != 0 {
sendf("s2=%d", device.awg.Cfg.ResponseHeaderJunkSize)
}
if device.awg.ASecCfg.CookieReplyHeaderJunkSize != 0 {
sendf("s3=%d", device.awg.ASecCfg.CookieReplyHeaderJunkSize)
if device.awg.Cfg.CookieReplyHeaderJunkSize != 0 {
sendf("s3=%d", device.awg.Cfg.CookieReplyHeaderJunkSize)
}
if device.awg.ASecCfg.TransportHeaderJunkSize != 0 {
sendf("s4=%d", device.awg.ASecCfg.TransportHeaderJunkSize)
if device.awg.Cfg.TransportHeaderJunkSize != 0 {
sendf("s4=%d", device.awg.Cfg.TransportHeaderJunkSize)
}
// TODO:
// if device.awg.ASecCfg.InitPacketMagicHeader != 0 {
@ -313,8 +313,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_count %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_count")
tempAwg.ASecCfg.JunkPacketCount = junkPacketCount
tempAwg.ASecCfg.IsSet = true
tempAwg.Cfg.JunkPacketCount = junkPacketCount
tempAwg.Cfg.IsSet = true
case "jmin":
junkPacketMinSize, err := strconv.Atoi(value)
@ -322,8 +322,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_min_size %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_min_size")
tempAwg.ASecCfg.JunkPacketMinSize = junkPacketMinSize
tempAwg.ASecCfg.IsSet = true
tempAwg.Cfg.JunkPacketMinSize = junkPacketMinSize
tempAwg.Cfg.IsSet = true
case "jmax":
junkPacketMaxSize, err := strconv.Atoi(value)
@ -331,8 +331,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_max_size %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_max_size")
tempAwg.ASecCfg.JunkPacketMaxSize = junkPacketMaxSize
tempAwg.ASecCfg.IsSet = true
tempAwg.Cfg.JunkPacketMaxSize = junkPacketMaxSize
tempAwg.Cfg.IsSet = true
case "s1":
initPacketJunkSize, err := strconv.Atoi(value)
@ -340,8 +340,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating init_packet_junk_size")
tempAwg.ASecCfg.InitHeaderJunkSize = initPacketJunkSize
tempAwg.ASecCfg.IsSet = true
tempAwg.Cfg.InitHeaderJunkSize = initPacketJunkSize
tempAwg.Cfg.IsSet = true
case "s2":
responsePacketJunkSize, err := strconv.Atoi(value)
@ -349,8 +349,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating response_packet_junk_size")
tempAwg.ASecCfg.ResponseHeaderJunkSize = responsePacketJunkSize
tempAwg.ASecCfg.IsSet = true
tempAwg.Cfg.ResponseHeaderJunkSize = responsePacketJunkSize
tempAwg.Cfg.IsSet = true
case "s3":
cookieReplyPacketJunkSize, err := strconv.Atoi(value)
@ -358,8 +358,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "parse cookie_reply_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating cookie_reply_packet_junk_size")
tempAwg.ASecCfg.CookieReplyHeaderJunkSize = cookieReplyPacketJunkSize
tempAwg.ASecCfg.IsSet = true
tempAwg.Cfg.CookieReplyHeaderJunkSize = cookieReplyPacketJunkSize
tempAwg.Cfg.IsSet = true
case "s4":
transportPacketJunkSize, err := strconv.Atoi(value)
@ -367,8 +367,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating transport_packet_junk_size")
tempAwg.ASecCfg.TransportHeaderJunkSize = transportPacketJunkSize
tempAwg.ASecCfg.IsSet = true
tempAwg.Cfg.TransportHeaderJunkSize = transportPacketJunkSize
tempAwg.Cfg.IsSet = true
case "h1":
initMagicHeader, err := awg.ParseMagicHeader(key, value)
@ -376,8 +376,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
}
tempAwg.ASecCfg.InitPacketMagicHeader = initMagicHeader
tempAwg.ASecCfg.IsSet = true
tempAwg.Cfg.InitPacketMagicHeader = initMagicHeader
tempAwg.Cfg.IsSet = true
case "h2":
responseMagicHeader, err := awg.ParseMagicHeader(key, value)
@ -385,8 +385,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
}
tempAwg.ASecCfg.ResponsePacketMagicHeader = responseMagicHeader
tempAwg.ASecCfg.IsSet = true
tempAwg.Cfg.ResponsePacketMagicHeader = responseMagicHeader
tempAwg.Cfg.IsSet = true
case "h3":
cookieReplyMagicHeader, err := awg.ParseMagicHeader(key, value)
@ -394,8 +394,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
}
tempAwg.ASecCfg.UnderloadPacketMagicHeader = cookieReplyMagicHeader
tempAwg.ASecCfg.IsSet = true
tempAwg.Cfg.UnderloadPacketMagicHeader = cookieReplyMagicHeader
tempAwg.Cfg.IsSet = true
case "h4":
transportMagicHeader, err := awg.ParseMagicHeader(key, value)
@ -403,8 +403,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
}
tempAwg.ASecCfg.TransportPacketMagicHeader = transportMagicHeader
tempAwg.ASecCfg.IsSet = true
tempAwg.Cfg.TransportPacketMagicHeader = transportMagicHeader
tempAwg.Cfg.IsSet = true
case "i1", "i2", "i3", "i4", "i5":
if len(value) == 0 {