From 6992e187552d434cb67e2f3f0c178cdbaff39cee Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Thu, 10 Jul 2025 19:56:03 +0200
Subject: [PATCH] chore: rename variables
---
device/awg/awg.go | 45 ++++----
device/awg/junk_creator.go | 30 ++---
device/awg/junk_creator_test.go | 14 +--
device/device.go | 195 +++++++++++++++++---------------
device/noise-protocol.go | 24 ++--
device/receive.go | 13 ++-
device/send.go | 10 +-
device/uapi.go | 72 ++++++------
8 files changed, 206 insertions(+), 197 deletions(-)
diff --git a/device/awg/awg.go b/device/awg/awg.go
index 336617e..29a4b62 100644
--- a/device/awg/awg.go
+++ b/device/awg/awg.go
@@ -8,7 +8,7 @@ import (
"github.com/tevino/abool"
)
-type aSecCfgType struct {
+type Cfg struct {
IsSet bool
JunkPacketCount int
JunkPacketMinSize int
@@ -25,42 +25,42 @@ type aSecCfgType struct {
}
type Protocol struct {
- IsASecOn abool.AtomicBool
+ IsOn abool.AtomicBool
// TODO: revision the need of the mutex
- ASecMux sync.RWMutex
- ASecCfg aSecCfgType
- JunkCreator junkCreator
+ Mux sync.RWMutex
+ Cfg Cfg
+ JunkCreator JunkCreator
MagicHeaders MagicHeaders
HandshakeHandler SpecialHandshakeHandler
}
func (protocol *Protocol) CreateInitHeaderJunk() ([]byte, error) {
- protocol.ASecMux.RLock()
- defer protocol.ASecMux.RUnlock()
+ protocol.Mux.RLock()
+ defer protocol.Mux.RUnlock()
- return protocol.createHeaderJunk(protocol.ASecCfg.InitHeaderJunkSize, 0)
+ return protocol.createHeaderJunk(protocol.Cfg.InitHeaderJunkSize, 0)
}
func (protocol *Protocol) CreateResponseHeaderJunk() ([]byte, error) {
- protocol.ASecMux.RLock()
- defer protocol.ASecMux.RUnlock()
+ protocol.Mux.RLock()
+ defer protocol.Mux.RUnlock()
- return protocol.createHeaderJunk(protocol.ASecCfg.ResponseHeaderJunkSize, 0)
+ return protocol.createHeaderJunk(protocol.Cfg.ResponseHeaderJunkSize, 0)
}
func (protocol *Protocol) CreateCookieReplyHeaderJunk() ([]byte, error) {
- protocol.ASecMux.RLock()
- defer protocol.ASecMux.RUnlock()
+ protocol.Mux.RLock()
+ defer protocol.Mux.RUnlock()
- return protocol.createHeaderJunk(protocol.ASecCfg.CookieReplyHeaderJunkSize, 0)
+ return protocol.createHeaderJunk(protocol.Cfg.CookieReplyHeaderJunkSize, 0)
}
func (protocol *Protocol) CreateTransportHeaderJunk(packetSize int) ([]byte, error) {
- protocol.ASecMux.RLock()
- defer protocol.ASecMux.RUnlock()
+ protocol.Mux.RLock()
+ defer protocol.Mux.RUnlock()
- return protocol.createHeaderJunk(protocol.ASecCfg.TransportHeaderJunkSize, packetSize)
+ return protocol.createHeaderJunk(protocol.Cfg.TransportHeaderJunkSize, packetSize)
}
func (protocol *Protocol) createHeaderJunk(junkSize int, extraSize int) ([]byte, error) {
@@ -68,7 +68,6 @@ func (protocol *Protocol) createHeaderJunk(junkSize int, extraSize int) ([]byte,
return nil, nil
}
- var junk []byte
buf := make([]byte, 0, junkSize+extraSize)
writer := bytes.NewBuffer(buf[:0])
@@ -77,19 +76,17 @@ func (protocol *Protocol) createHeaderJunk(junkSize int, extraSize int) ([]byte,
return nil, fmt.Errorf("append junk: %w", err)
}
- junk = writer.Bytes()
-
- return junk, nil
+ return writer.Bytes(), nil
}
-func (protocol *Protocol) GetMagicHeaderMinFor(msgTypeRange uint32) (uint32, error) {
+func (protocol *Protocol) GetMagicHeaderMinFor(msgType uint32) (uint32, error) {
for _, limit := range protocol.MagicHeaders.headerValues {
- if limit.Min <= msgTypeRange && msgTypeRange <= limit.Max {
+ if limit.Min <= msgType && msgType <= limit.Max {
return limit.Min, nil
}
}
- return 0, fmt.Errorf("no header for range: %d", msgTypeRange)
+ return 0, fmt.Errorf("no header for value: %d", msgType)
}
func (protocol *Protocol) GetMsgType(defaultMsgType uint32) (uint32, error) {
diff --git a/device/awg/junk_creator.go b/device/awg/junk_creator.go
index b41441a..8ba2918 100644
--- a/device/awg/junk_creator.go
+++ b/device/awg/junk_creator.go
@@ -5,23 +5,23 @@ import (
"fmt"
)
-type junkCreator struct {
- aSecCfg aSecCfgType
+type JunkCreator struct {
+ cfg Cfg
randomGenerator PRNG[int]
}
// TODO: refactor param to only pass the junk related params
-func NewJunkCreator(aSecCfg aSecCfgType) junkCreator {
- return junkCreator{aSecCfg: aSecCfg, randomGenerator: NewPRNG[int]()}
+func NewJunkCreator(cfg Cfg) JunkCreator {
+ return JunkCreator{cfg: cfg, randomGenerator: NewPRNG[int]()}
}
-// Should be called with aSecMux RLocked
-func (jc *junkCreator) CreateJunkPackets(junks *[][]byte) {
- if jc.aSecCfg.JunkPacketCount == 0 {
+// Should be called with awg mux RLocked
+func (jc *JunkCreator) CreateJunkPackets(junks *[][]byte) {
+ if jc.cfg.JunkPacketCount == 0 {
return
}
- for range jc.aSecCfg.JunkPacketCount {
+ for range jc.cfg.JunkPacketCount {
packetSize := jc.randomPacketSize()
junk := jc.randomJunkWithSize(packetSize)
*junks = append(*junks, junk)
@@ -29,13 +29,13 @@ func (jc *junkCreator) CreateJunkPackets(junks *[][]byte) {
return
}
-// Should be called with aSecMux RLocked
-func (jc *junkCreator) randomPacketSize() int {
- return jc.randomGenerator.RandomSizeInRange(jc.aSecCfg.JunkPacketMinSize, jc.aSecCfg.JunkPacketMaxSize)
+// Should be called with awg mux RLocked
+func (jc *JunkCreator) randomPacketSize() int {
+ return jc.randomGenerator.RandomSizeInRange(jc.cfg.JunkPacketMinSize, jc.cfg.JunkPacketMaxSize)
}
-// Should be called with aSecMux RLocked
-func (jc *junkCreator) AppendJunk(writer *bytes.Buffer, size int) error {
+// Should be called with awg mux RLocked
+func (jc *JunkCreator) AppendJunk(writer *bytes.Buffer, size int) error {
headerJunk := jc.randomJunkWithSize(size)
_, err := writer.Write(headerJunk)
if err != nil {
@@ -44,7 +44,7 @@ func (jc *junkCreator) AppendJunk(writer *bytes.Buffer, size int) error {
return nil
}
-// Should be called with aSecMux RLocked
-func (jc *junkCreator) randomJunkWithSize(size int) []byte {
+// Should be called with awg mux RLocked
+func (jc *JunkCreator) randomJunkWithSize(size int) []byte {
return jc.randomGenerator.ReadSize(size)
}
diff --git a/device/awg/junk_creator_test.go b/device/awg/junk_creator_test.go
index 3553a93..33532e4 100644
--- a/device/awg/junk_creator_test.go
+++ b/device/awg/junk_creator_test.go
@@ -6,8 +6,8 @@ import (
"testing"
)
-func setUpJunkCreator() junkCreator {
- jc := NewJunkCreator(aSecCfgType{
+func setUpJunkCreator() JunkCreator {
+ jc := NewJunkCreator(Cfg{
IsSet: true,
JunkPacketCount: 5,
JunkPacketMinSize: 500,
@@ -27,7 +27,7 @@ func setUpJunkCreator() junkCreator {
func Test_junkCreator_createJunkPackets(t *testing.T) {
jc := setUpJunkCreator()
t.Run("valid", func(t *testing.T) {
- got := make([][]byte, 0, jc.aSecCfg.JunkPacketCount)
+ got := make([][]byte, 0, jc.cfg.JunkPacketCount)
jc.CreateJunkPackets(&got)
seen := make(map[string]bool)
for _, junk := range got {
@@ -62,13 +62,13 @@ func Test_junkCreator_randomPacketSize(t *testing.T) {
jc := setUpJunkCreator()
for range [30]struct{}{} {
t.Run("valid", func(t *testing.T) {
- if got := jc.randomPacketSize(); jc.aSecCfg.JunkPacketMinSize > got ||
- got > jc.aSecCfg.JunkPacketMaxSize {
+ if got := jc.randomPacketSize(); jc.cfg.JunkPacketMinSize > got ||
+ got > jc.cfg.JunkPacketMaxSize {
t.Errorf(
"junkCreator.randomPacketSize() = %v, not between range [%v,%v]",
got,
- jc.aSecCfg.JunkPacketMinSize,
- jc.aSecCfg.JunkPacketMaxSize,
+ jc.cfg.JunkPacketMinSize,
+ jc.cfg.JunkPacketMaxSize,
)
}
})
diff --git a/device/device.go b/device/device.go
index 1a2b726..fe71037 100644
--- a/device/device.go
+++ b/device/device.go
@@ -593,105 +593,105 @@ func (device *Device) resetProtocol() {
}
func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
- if !tempAwg.ASecCfg.IsSet && !tempAwg.HandshakeHandler.IsSet {
+ if !tempAwg.Cfg.IsSet && !tempAwg.HandshakeHandler.IsSet {
return nil
}
var errs []error
isASecOn := false
- device.awg.ASecMux.Lock()
- if tempAwg.ASecCfg.JunkPacketCount < 0 {
+ device.awg.Mux.Lock()
+ if tempAwg.Cfg.JunkPacketCount < 0 {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketCount should be non negative",
),
)
}
- device.awg.ASecCfg.JunkPacketCount = tempAwg.ASecCfg.JunkPacketCount
- if tempAwg.ASecCfg.JunkPacketCount != 0 {
+ device.awg.Cfg.JunkPacketCount = tempAwg.Cfg.JunkPacketCount
+ if tempAwg.Cfg.JunkPacketCount != 0 {
isASecOn = true
}
- device.awg.ASecCfg.JunkPacketMinSize = tempAwg.ASecCfg.JunkPacketMinSize
- if tempAwg.ASecCfg.JunkPacketMinSize != 0 {
+ device.awg.Cfg.JunkPacketMinSize = tempAwg.Cfg.JunkPacketMinSize
+ if tempAwg.Cfg.JunkPacketMinSize != 0 {
isASecOn = true
}
- if device.awg.ASecCfg.JunkPacketCount > 0 &&
- tempAwg.ASecCfg.JunkPacketMaxSize == tempAwg.ASecCfg.JunkPacketMinSize {
+ if device.awg.Cfg.JunkPacketCount > 0 &&
+ tempAwg.Cfg.JunkPacketMaxSize == tempAwg.Cfg.JunkPacketMinSize {
- tempAwg.ASecCfg.JunkPacketMaxSize++ // to make rand gen work
+ tempAwg.Cfg.JunkPacketMaxSize++ // to make rand gen work
}
- if tempAwg.ASecCfg.JunkPacketMaxSize >= MaxSegmentSize {
- device.awg.ASecCfg.JunkPacketMinSize = 0
- device.awg.ASecCfg.JunkPacketMaxSize = 1
+ if tempAwg.Cfg.JunkPacketMaxSize >= MaxSegmentSize {
+ device.awg.Cfg.JunkPacketMinSize = 0
+ device.awg.Cfg.JunkPacketMaxSize = 1
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
- tempAwg.ASecCfg.JunkPacketMaxSize,
+ tempAwg.Cfg.JunkPacketMaxSize,
MaxSegmentSize,
))
- } else if tempAwg.ASecCfg.JunkPacketMaxSize < tempAwg.ASecCfg.JunkPacketMinSize {
+ } else if tempAwg.Cfg.JunkPacketMaxSize < tempAwg.Cfg.JunkPacketMinSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
"maxSize: %d; should be greater than minSize: %d",
- tempAwg.ASecCfg.JunkPacketMaxSize,
- tempAwg.ASecCfg.JunkPacketMinSize,
+ tempAwg.Cfg.JunkPacketMaxSize,
+ tempAwg.Cfg.JunkPacketMinSize,
))
} else {
- device.awg.ASecCfg.JunkPacketMaxSize = tempAwg.ASecCfg.JunkPacketMaxSize
+ device.awg.Cfg.JunkPacketMaxSize = tempAwg.Cfg.JunkPacketMaxSize
}
- if tempAwg.ASecCfg.JunkPacketMaxSize != 0 {
+ if tempAwg.Cfg.JunkPacketMaxSize != 0 {
isASecOn = true
}
limits := make([]awg.MagicHeader, 4)
- if tempAwg.ASecCfg.InitPacketMagicHeader.Min > 4 {
+ if tempAwg.Cfg.InitPacketMagicHeader.Min > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating init_packet_magic_header")
- device.awg.ASecCfg.InitPacketMagicHeader = tempAwg.ASecCfg.InitPacketMagicHeader
- limits[0] = tempAwg.ASecCfg.InitPacketMagicHeader
- MessageInitiationType = device.awg.ASecCfg.InitPacketMagicHeader.Min
+ device.awg.Cfg.InitPacketMagicHeader = tempAwg.Cfg.InitPacketMagicHeader
+ limits[0] = tempAwg.Cfg.InitPacketMagicHeader
+ MessageInitiationType = device.awg.Cfg.InitPacketMagicHeader.Min
} else {
device.log.Verbosef("UAPI: Using default init type")
MessageInitiationType = DefaultMessageInitiationType
limits[0] = awg.NewMagicHeaderSameValue(DefaultMessageInitiationType)
}
- if tempAwg.ASecCfg.ResponsePacketMagicHeader.Min > 4 {
+ if tempAwg.Cfg.ResponsePacketMagicHeader.Min > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
- device.awg.ASecCfg.ResponsePacketMagicHeader = tempAwg.ASecCfg.ResponsePacketMagicHeader
- MessageResponseType = device.awg.ASecCfg.ResponsePacketMagicHeader.Min
- limits[1] = tempAwg.ASecCfg.ResponsePacketMagicHeader
+ device.awg.Cfg.ResponsePacketMagicHeader = tempAwg.Cfg.ResponsePacketMagicHeader
+ MessageResponseType = device.awg.Cfg.ResponsePacketMagicHeader.Min
+ limits[1] = tempAwg.Cfg.ResponsePacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default response type")
MessageResponseType = DefaultMessageResponseType
limits[1] = awg.NewMagicHeaderSameValue(DefaultMessageResponseType)
}
- if tempAwg.ASecCfg.UnderloadPacketMagicHeader.Min > 4 {
+ if tempAwg.Cfg.UnderloadPacketMagicHeader.Min > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
- device.awg.ASecCfg.UnderloadPacketMagicHeader = tempAwg.ASecCfg.UnderloadPacketMagicHeader
- MessageCookieReplyType = device.awg.ASecCfg.UnderloadPacketMagicHeader.Min
- limits[2] = tempAwg.ASecCfg.UnderloadPacketMagicHeader
+ device.awg.Cfg.UnderloadPacketMagicHeader = tempAwg.Cfg.UnderloadPacketMagicHeader
+ MessageCookieReplyType = device.awg.Cfg.UnderloadPacketMagicHeader.Min
+ limits[2] = tempAwg.Cfg.UnderloadPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default underload type")
MessageCookieReplyType = DefaultMessageCookieReplyType
limits[2] = awg.NewMagicHeaderSameValue(DefaultMessageCookieReplyType)
}
- if tempAwg.ASecCfg.TransportPacketMagicHeader.Min > 4 {
+ if tempAwg.Cfg.TransportPacketMagicHeader.Min > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
- device.awg.ASecCfg.TransportPacketMagicHeader = tempAwg.ASecCfg.TransportPacketMagicHeader
- MessageTransportType = device.awg.ASecCfg.TransportPacketMagicHeader.Min
- limits[3] = tempAwg.ASecCfg.TransportPacketMagicHeader
+ device.awg.Cfg.TransportPacketMagicHeader = tempAwg.Cfg.TransportPacketMagicHeader
+ MessageTransportType = device.awg.Cfg.TransportPacketMagicHeader.Min
+ limits[3] = tempAwg.Cfg.TransportPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default transport type")
MessageTransportType = DefaultMessageTransportType
@@ -724,75 +724,75 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
)
}
- newInitSize := MessageInitiationSize + tempAwg.ASecCfg.InitHeaderJunkSize
+ newInitSize := MessageInitiationSize + tempAwg.Cfg.InitHeaderJunkSize
if newInitSize >= MaxSegmentSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
- tempAwg.ASecCfg.InitHeaderJunkSize,
+ tempAwg.Cfg.InitHeaderJunkSize,
MaxSegmentSize,
),
)
} else {
- device.awg.ASecCfg.InitHeaderJunkSize = tempAwg.ASecCfg.InitHeaderJunkSize
+ device.awg.Cfg.InitHeaderJunkSize = tempAwg.Cfg.InitHeaderJunkSize
}
- if tempAwg.ASecCfg.InitHeaderJunkSize != 0 {
+ if tempAwg.Cfg.InitHeaderJunkSize != 0 {
isASecOn = true
}
- newResponseSize := MessageResponseSize + tempAwg.ASecCfg.ResponseHeaderJunkSize
+ newResponseSize := MessageResponseSize + tempAwg.Cfg.ResponseHeaderJunkSize
if newResponseSize >= MaxSegmentSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
- tempAwg.ASecCfg.ResponseHeaderJunkSize,
+ tempAwg.Cfg.ResponseHeaderJunkSize,
MaxSegmentSize,
),
)
} else {
- device.awg.ASecCfg.ResponseHeaderJunkSize = tempAwg.ASecCfg.ResponseHeaderJunkSize
+ device.awg.Cfg.ResponseHeaderJunkSize = tempAwg.Cfg.ResponseHeaderJunkSize
}
- if tempAwg.ASecCfg.ResponseHeaderJunkSize != 0 {
+ if tempAwg.Cfg.ResponseHeaderJunkSize != 0 {
isASecOn = true
}
- newCookieSize := MessageCookieReplySize + tempAwg.ASecCfg.CookieReplyHeaderJunkSize
+ newCookieSize := MessageCookieReplySize + tempAwg.Cfg.CookieReplyHeaderJunkSize
if newCookieSize >= MaxSegmentSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`cookie reply size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
- tempAwg.ASecCfg.CookieReplyHeaderJunkSize,
+ tempAwg.Cfg.CookieReplyHeaderJunkSize,
MaxSegmentSize,
),
)
} else {
- device.awg.ASecCfg.CookieReplyHeaderJunkSize = tempAwg.ASecCfg.CookieReplyHeaderJunkSize
+ device.awg.Cfg.CookieReplyHeaderJunkSize = tempAwg.Cfg.CookieReplyHeaderJunkSize
}
- if tempAwg.ASecCfg.CookieReplyHeaderJunkSize != 0 {
+ if tempAwg.Cfg.CookieReplyHeaderJunkSize != 0 {
isASecOn = true
}
- newTransportSize := MessageTransportSize + tempAwg.ASecCfg.TransportHeaderJunkSize
+ newTransportSize := MessageTransportSize + tempAwg.Cfg.TransportHeaderJunkSize
if newTransportSize >= MaxSegmentSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`transport size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
- tempAwg.ASecCfg.TransportHeaderJunkSize,
+ tempAwg.Cfg.TransportHeaderJunkSize,
MaxSegmentSize,
),
)
} else {
- device.awg.ASecCfg.TransportHeaderJunkSize = tempAwg.ASecCfg.TransportHeaderJunkSize
+ device.awg.Cfg.TransportHeaderJunkSize = tempAwg.Cfg.TransportHeaderJunkSize
}
- if tempAwg.ASecCfg.TransportHeaderJunkSize != 0 {
+ if tempAwg.Cfg.TransportHeaderJunkSize != 0 {
isASecOn = true
}
@@ -815,10 +815,10 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
)
} else {
msgTypeToJunkSize = map[uint32]int{
- MessageInitiationType: device.awg.ASecCfg.InitHeaderJunkSize,
- MessageResponseType: device.awg.ASecCfg.ResponseHeaderJunkSize,
- MessageCookieReplyType: device.awg.ASecCfg.CookieReplyHeaderJunkSize,
- MessageTransportType: device.awg.ASecCfg.TransportHeaderJunkSize,
+ MessageInitiationType: device.awg.Cfg.InitHeaderJunkSize,
+ MessageResponseType: device.awg.Cfg.ResponseHeaderJunkSize,
+ MessageCookieReplyType: device.awg.Cfg.CookieReplyHeaderJunkSize,
+ MessageTransportType: device.awg.Cfg.TransportHeaderJunkSize,
}
packetSizeToMsgType = map[int]uint32{
@@ -829,8 +829,8 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
}
}
- device.awg.IsASecOn.SetTo(isASecOn)
- device.awg.JunkCreator = awg.NewJunkCreator(device.awg.ASecCfg)
+ device.awg.IsOn.SetTo(isASecOn)
+ device.awg.JunkCreator = awg.NewJunkCreator(device.awg.Cfg)
if tempAwg.HandshakeHandler.IsSet {
if err := tempAwg.HandshakeHandler.Validate(); err != nil {
@@ -838,78 +838,89 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
ipc.IpcErrorInvalid, "handshake handler validate: %w", err))
} else {
device.awg.HandshakeHandler = tempAwg.HandshakeHandler
- device.awg.HandshakeHandler.ControlledJunk.DefaultJunkCount = tempAwg.ASecCfg.JunkPacketCount
- device.awg.HandshakeHandler.SpecialJunk.DefaultJunkCount = tempAwg.ASecCfg.JunkPacketCount
+ device.awg.HandshakeHandler.ControlledJunk.DefaultJunkCount = tempAwg.Cfg.JunkPacketCount
+ device.awg.HandshakeHandler.SpecialJunk.DefaultJunkCount = tempAwg.Cfg.JunkPacketCount
device.version = VersionAwgSpecialHandshake
}
} else {
device.version = VersionAwg
}
- device.awg.ASecMux.Unlock()
+ device.awg.Mux.Unlock()
return errors.Join(errs...)
}
-func (device *Device) Logic(size int, packet *[]byte, bufsArrs *[MaxMessageSize]byte) (msgType uint32, err error) {
+func (device *Device) ProcessAWGPacket(size int, packet *[]byte, buffer *[MaxMessageSize]byte) (uint32, error) {
// TODO:
// if awg.WaitResponse.ShouldWait.IsSet() {
// awg.WaitResponse.Channel <- struct{}{}
// }
- assumedMsgType, hasAssumedType := packetSizeToMsgType[size]
- if !hasAssumedType {
- return device.handleTransport(size, packet, bufsArrs)
- }
+ expectedMsgType, isKnownSize := packetSizeToMsgType[size]
+ if !isKnownSize {
+ msgType, err := device.handleTransport(size, packet, buffer)
- junkSize := msgTypeToJunkSize[assumedMsgType]
+ if err != nil {
+ return 0, fmt.Errorf("handle transport: %w", err)
+ }
- // transport size can align with other header types;
- // making sure we have the right msgType
- msgType, err = device.getMsgType(packet, junkSize)
- if err != nil {
- return 0, fmt.Errorf("aSec: get msg type: %w", err)
- }
-
- if msgType == assumedMsgType {
- *packet = (*packet)[junkSize:]
return msgType, nil
}
- device.log.Verbosef("transport packet lined up with another msg type")
-
- return device.handleTransport(size, packet, bufsArrs)
-}
-
-func (device *Device) getMsgType(packet *[]byte, junkSize int) (uint32, error) {
- msgTypeRange := binary.LittleEndian.Uint32((*packet)[junkSize : junkSize+4])
- msgType, err := device.awg.GetMagicHeaderMinFor(msgTypeRange)
+ junkSize := msgTypeToJunkSize[expectedMsgType]
+ // transport size can align with other header types;
+ // making sure we have the right actualMsgType
+ actualMsgType, err := device.getMsgType(packet, junkSize)
if err != nil {
- return 0, fmt.Errorf("get limit min: %w", err)
+ return 0, fmt.Errorf("get msg type: %w", err)
+ }
+
+ if actualMsgType == expectedMsgType {
+ *packet = (*packet)[junkSize:]
+ return actualMsgType, nil
+ }
+
+ device.log.Verbosef("awg: transport packet lined up with another msg type")
+
+ msgType, err := device.handleTransport(size, packet, buffer)
+ if err != nil {
+ return 0, fmt.Errorf("handle transport: %w", err)
}
return msgType, nil
}
-func (device *Device) handleTransport(size int, packet *[]byte, bufsArrs *[MaxMessageSize]byte) (uint32, error) {
- transportJunkSize := device.awg.ASecCfg.TransportHeaderJunkSize
+func (device *Device) getMsgType(packet *[]byte, junkSize int) (uint32, error) {
+ msgTypeValue := binary.LittleEndian.Uint32((*packet)[junkSize : junkSize+4])
+ msgType, err := device.awg.GetMagicHeaderMinFor(msgTypeValue)
- msgType, err := device.getMsgType(packet, device.awg.ASecCfg.TransportHeaderJunkSize)
+ if err != nil {
+ return 0, fmt.Errorf("get magic header min: %w", err)
+ }
+
+ return msgType, nil
+}
+
+func (device *Device) handleTransport(size int, packet *[]byte, buffer *[MaxMessageSize]byte) (uint32, error) {
+ junkSize := device.awg.Cfg.TransportHeaderJunkSize
+
+ msgType, err := device.getMsgType(packet, junkSize)
if err != nil {
return 0, fmt.Errorf("get msg type: %w", err)
}
if msgType != MessageTransportType {
// probably a junk packet
- return 0, fmt.Errorf("aSec: Received message with unknown type: %d", msgType)
+ return 0, fmt.Errorf("Received message with unknown type: %d", msgType)
}
- if transportJunkSize > 0 {
- // remove junk from bufsArrs by shifting the packet
+ if junkSize > 0 {
+ // remove junk from buffer by shifting the packet
// this buffer is also used for decryption, so it needs to be corrected
- copy((*bufsArrs)[:size], (*packet)[transportJunkSize:])
- size -= transportJunkSize
+ copy((*buffer)[:size], (*packet)[junkSize:])
+ size -= junkSize
// need to reinitialize packet as well
(*packet) = (*packet)[:size]
}
diff --git a/device/noise-protocol.go b/device/noise-protocol.go
index bce2d2f..bac089d 100644
--- a/device/noise-protocol.go
+++ b/device/noise-protocol.go
@@ -205,10 +205,10 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(handshake.remoteStatic[:])
- device.awg.ASecMux.RLock()
+ device.awg.Mux.RLock()
msgType, err := device.awg.GetMsgType(DefaultMessageInitiationType)
if err != nil {
- device.awg.ASecMux.RUnlock()
+ device.awg.Mux.RUnlock()
return nil, fmt.Errorf("get message type: %w", err)
}
@@ -216,7 +216,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
Type: msgType,
Ephemeral: handshake.localEphemeral.publicKey(),
}
- device.awg.ASecMux.RUnlock()
+ device.awg.Mux.RUnlock()
handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:])
@@ -270,13 +270,13 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
chainKey [blake2s.Size]byte
)
- device.awg.ASecMux.RLock()
+ device.awg.Mux.RLock()
if msg.Type != MessageInitiationType {
- device.awg.ASecMux.RUnlock()
+ device.awg.Mux.RUnlock()
return nil
}
- device.awg.ASecMux.RUnlock()
+ device.awg.Mux.RUnlock()
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
@@ -391,14 +391,14 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
}
var msg MessageResponse
- device.awg.ASecMux.RLock()
+ device.awg.Mux.RLock()
msg.Type, err = device.awg.GetMsgType(DefaultMessageResponseType)
if err != nil {
- device.awg.ASecMux.RUnlock()
+ device.awg.Mux.RUnlock()
return nil, fmt.Errorf("get message type: %w", err)
}
- device.awg.ASecMux.RUnlock()
+ device.awg.Mux.RUnlock()
msg.Sender = handshake.localIndex
msg.Receiver = handshake.remoteIndex
@@ -448,13 +448,13 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
}
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
- device.awg.ASecMux.RLock()
+ device.awg.Mux.RLock()
if msg.Type != MessageResponseType {
- device.awg.ASecMux.RUnlock()
+ device.awg.Mux.RUnlock()
return nil
}
- device.awg.ASecMux.RUnlock()
+ device.awg.Mux.RUnlock()
// lookup handshake by receiver
diff --git a/device/receive.go b/device/receive.go
index 41ffa2b..8ace3a6 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -129,7 +129,7 @@ func (device *Device) RoutineReceiveIncoming(
}
deathSpiral = 0
- device.awg.ASecMux.RLock()
+ device.awg.Mux.RLock()
// handle each packet in the batch
for i, size := range sizes[:count] {
if size < MinMessageSize {
@@ -140,9 +140,10 @@ func (device *Device) RoutineReceiveIncoming(
packet := bufsArrs[i][:size]
var msgType uint32
if device.isAWG() {
- msgType, err = device.Logic(size, &packet, bufsArrs[i])
+ msgType, err = device.ProcessAWGPacket(size, &packet, bufsArrs[i])
+
if err != nil {
- device.log.Verbosef("awg device logic: %v", err)
+ device.log.Verbosef("awg: process packet: %v", err)
continue
}
} else {
@@ -232,7 +233,7 @@ func (device *Device) RoutineReceiveIncoming(
default:
}
}
- device.awg.ASecMux.RUnlock()
+ device.awg.Mux.RUnlock()
for peer, elemsContainer := range elemsByPeer {
if peer.isRunning.Load() {
peer.queue.inbound.c <- elemsContainer
@@ -291,7 +292,7 @@ func (device *Device) RoutineHandshake(id int) {
for elem := range device.queue.handshake.c {
- device.awg.ASecMux.RLock()
+ device.awg.Mux.RLock()
// handle cookie fields and ratelimiting
@@ -449,7 +450,7 @@ func (device *Device) RoutineHandshake(id int) {
peer.SendKeepalive()
}
skip:
- device.awg.ASecMux.RUnlock()
+ device.awg.Mux.RUnlock()
device.PutMessageBuffer(elem.buffer)
}
}
diff --git a/device/send.go b/device/send.go
index b87dda2..296e9a8 100644
--- a/device/send.go
+++ b/device/send.go
@@ -130,7 +130,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
if peer.device.version >= VersionAwg {
var junks [][]byte
if peer.device.version == VersionAwgSpecialHandshake {
- peer.device.awg.ASecMux.RLock()
+ peer.device.awg.Mux.RLock()
// set junks depending on packet type
junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk()
if junks == nil {
@@ -141,13 +141,13 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
} else {
peer.device.log.Verbosef("%v - Special junks sent", peer)
}
- peer.device.awg.ASecMux.RUnlock()
+ peer.device.awg.Mux.RUnlock()
} else {
- junks = make([][]byte, 0, peer.device.awg.ASecCfg.JunkPacketCount)
+ junks = make([][]byte, 0, peer.device.awg.Cfg.JunkPacketCount)
}
- peer.device.awg.ASecMux.RLock()
+ peer.device.awg.Mux.RLock()
peer.device.awg.JunkCreator.CreateJunkPackets(&junks)
- peer.device.awg.ASecMux.RUnlock()
+ peer.device.awg.Mux.RUnlock()
if len(junks) > 0 {
err = peer.SendBuffers(junks)
diff --git a/device/uapi.go b/device/uapi.go
index d99a4ee..f4cb834 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -99,26 +99,26 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
}
if device.isAWG() {
- if device.awg.ASecCfg.JunkPacketCount != 0 {
- sendf("jc=%d", device.awg.ASecCfg.JunkPacketCount)
+ if device.awg.Cfg.JunkPacketCount != 0 {
+ sendf("jc=%d", device.awg.Cfg.JunkPacketCount)
}
- if device.awg.ASecCfg.JunkPacketMinSize != 0 {
- sendf("jmin=%d", device.awg.ASecCfg.JunkPacketMinSize)
+ if device.awg.Cfg.JunkPacketMinSize != 0 {
+ sendf("jmin=%d", device.awg.Cfg.JunkPacketMinSize)
}
- if device.awg.ASecCfg.JunkPacketMaxSize != 0 {
- sendf("jmax=%d", device.awg.ASecCfg.JunkPacketMaxSize)
+ if device.awg.Cfg.JunkPacketMaxSize != 0 {
+ sendf("jmax=%d", device.awg.Cfg.JunkPacketMaxSize)
}
- if device.awg.ASecCfg.InitHeaderJunkSize != 0 {
- sendf("s1=%d", device.awg.ASecCfg.InitHeaderJunkSize)
+ if device.awg.Cfg.InitHeaderJunkSize != 0 {
+ sendf("s1=%d", device.awg.Cfg.InitHeaderJunkSize)
}
- if device.awg.ASecCfg.ResponseHeaderJunkSize != 0 {
- sendf("s2=%d", device.awg.ASecCfg.ResponseHeaderJunkSize)
+ if device.awg.Cfg.ResponseHeaderJunkSize != 0 {
+ sendf("s2=%d", device.awg.Cfg.ResponseHeaderJunkSize)
}
- if device.awg.ASecCfg.CookieReplyHeaderJunkSize != 0 {
- sendf("s3=%d", device.awg.ASecCfg.CookieReplyHeaderJunkSize)
+ if device.awg.Cfg.CookieReplyHeaderJunkSize != 0 {
+ sendf("s3=%d", device.awg.Cfg.CookieReplyHeaderJunkSize)
}
- if device.awg.ASecCfg.TransportHeaderJunkSize != 0 {
- sendf("s4=%d", device.awg.ASecCfg.TransportHeaderJunkSize)
+ if device.awg.Cfg.TransportHeaderJunkSize != 0 {
+ sendf("s4=%d", device.awg.Cfg.TransportHeaderJunkSize)
}
// TODO:
// if device.awg.ASecCfg.InitPacketMagicHeader != 0 {
@@ -313,8 +313,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_count %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_count")
- tempAwg.ASecCfg.JunkPacketCount = junkPacketCount
- tempAwg.ASecCfg.IsSet = true
+ tempAwg.Cfg.JunkPacketCount = junkPacketCount
+ tempAwg.Cfg.IsSet = true
case "jmin":
junkPacketMinSize, err := strconv.Atoi(value)
@@ -322,8 +322,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_min_size %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_min_size")
- tempAwg.ASecCfg.JunkPacketMinSize = junkPacketMinSize
- tempAwg.ASecCfg.IsSet = true
+ tempAwg.Cfg.JunkPacketMinSize = junkPacketMinSize
+ tempAwg.Cfg.IsSet = true
case "jmax":
junkPacketMaxSize, err := strconv.Atoi(value)
@@ -331,8 +331,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_max_size %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_max_size")
- tempAwg.ASecCfg.JunkPacketMaxSize = junkPacketMaxSize
- tempAwg.ASecCfg.IsSet = true
+ tempAwg.Cfg.JunkPacketMaxSize = junkPacketMaxSize
+ tempAwg.Cfg.IsSet = true
case "s1":
initPacketJunkSize, err := strconv.Atoi(value)
@@ -340,8 +340,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating init_packet_junk_size")
- tempAwg.ASecCfg.InitHeaderJunkSize = initPacketJunkSize
- tempAwg.ASecCfg.IsSet = true
+ tempAwg.Cfg.InitHeaderJunkSize = initPacketJunkSize
+ tempAwg.Cfg.IsSet = true
case "s2":
responsePacketJunkSize, err := strconv.Atoi(value)
@@ -349,8 +349,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating response_packet_junk_size")
- tempAwg.ASecCfg.ResponseHeaderJunkSize = responsePacketJunkSize
- tempAwg.ASecCfg.IsSet = true
+ tempAwg.Cfg.ResponseHeaderJunkSize = responsePacketJunkSize
+ tempAwg.Cfg.IsSet = true
case "s3":
cookieReplyPacketJunkSize, err := strconv.Atoi(value)
@@ -358,8 +358,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "parse cookie_reply_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating cookie_reply_packet_junk_size")
- tempAwg.ASecCfg.CookieReplyHeaderJunkSize = cookieReplyPacketJunkSize
- tempAwg.ASecCfg.IsSet = true
+ tempAwg.Cfg.CookieReplyHeaderJunkSize = cookieReplyPacketJunkSize
+ tempAwg.Cfg.IsSet = true
case "s4":
transportPacketJunkSize, err := strconv.Atoi(value)
@@ -367,8 +367,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating transport_packet_junk_size")
- tempAwg.ASecCfg.TransportHeaderJunkSize = transportPacketJunkSize
- tempAwg.ASecCfg.IsSet = true
+ tempAwg.Cfg.TransportHeaderJunkSize = transportPacketJunkSize
+ tempAwg.Cfg.IsSet = true
case "h1":
initMagicHeader, err := awg.ParseMagicHeader(key, value)
@@ -376,8 +376,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
}
- tempAwg.ASecCfg.InitPacketMagicHeader = initMagicHeader
- tempAwg.ASecCfg.IsSet = true
+ tempAwg.Cfg.InitPacketMagicHeader = initMagicHeader
+ tempAwg.Cfg.IsSet = true
case "h2":
responseMagicHeader, err := awg.ParseMagicHeader(key, value)
@@ -385,8 +385,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
}
- tempAwg.ASecCfg.ResponsePacketMagicHeader = responseMagicHeader
- tempAwg.ASecCfg.IsSet = true
+ tempAwg.Cfg.ResponsePacketMagicHeader = responseMagicHeader
+ tempAwg.Cfg.IsSet = true
case "h3":
cookieReplyMagicHeader, err := awg.ParseMagicHeader(key, value)
@@ -394,8 +394,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
}
- tempAwg.ASecCfg.UnderloadPacketMagicHeader = cookieReplyMagicHeader
- tempAwg.ASecCfg.IsSet = true
+ tempAwg.Cfg.UnderloadPacketMagicHeader = cookieReplyMagicHeader
+ tempAwg.Cfg.IsSet = true
case "h4":
transportMagicHeader, err := awg.ParseMagicHeader(key, value)
@@ -403,8 +403,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
}
- tempAwg.ASecCfg.TransportPacketMagicHeader = transportMagicHeader
- tempAwg.ASecCfg.IsSet = true
+ tempAwg.Cfg.TransportPacketMagicHeader = transportMagicHeader
+ tempAwg.Cfg.IsSet = true
case "i1", "i2", "i3", "i4", "i5":
if len(value) == 0 {