From 6b0bbcc75cb88c1cf1370ca40b8907e378b205e8 Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Sun, 9 Feb 2025 09:36:53 +0100
Subject: [PATCH] migrate awg specific code to struct
---
device/device.go | 65 ++++++++++++++++++++-----------------
device/junk_creator.go | 10 +++---
device/junk_creator_test.go | 8 ++---
device/noise-protocol.go | 20 ++++++------
device/receive.go | 12 +++----
device/send.go | 34 +++++++++----------
device/uapi.go | 36 ++++++++++----------
7 files changed, 95 insertions(+), 90 deletions(-)
diff --git a/device/device.go b/device/device.go
index aaddcbc..f375663 100644
--- a/device/device.go
+++ b/device/device.go
@@ -92,14 +92,17 @@ type Device struct {
ipcMutex sync.RWMutex
closed chan struct{}
log *Logger
+
+ awg awgType
+}
+type awgType struct {
isASecOn abool.AtomicBool
aSecMux sync.RWMutex
aSecCfg aSecCfgType
junkCreator junkCreator
luaAdapter *adapter.Lua
- packetCounter atomic.Int64
}
type aSecCfgType struct {
@@ -431,9 +434,9 @@ func (device *Device) Close() {
device.resetProtocol()
- if device.luaAdapter != nil {
- device.luaAdapter.Close()
- device.luaAdapter = nil
+ if device.awg.luaAdapter != nil {
+ device.awg.luaAdapter.Close()
+ device.awg.luaAdapter = nil
}
device.log.Verbosef("Device closed")
close(device.closed)
@@ -581,7 +584,7 @@ func (device *Device) BindClose() error {
return err
}
func (device *Device) isAdvancedSecurityOn() bool {
- return device.isASecOn.IsSet()
+ return device.awg.isASecOn.IsSet()
}
func (device *Device) resetProtocol() {
@@ -594,36 +597,36 @@ func (device *Device) resetProtocol() {
func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
if !tempASecCfg.isSet {
- return err
+ return nil
}
isASecOn := false
- device.aSecMux.Lock()
+ device.awg.aSecMux.Lock()
if tempASecCfg.junkPacketCount < 0 {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketCount should be non negative",
)
}
- device.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount
+ device.awg.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount
if tempASecCfg.junkPacketCount != 0 {
isASecOn = true
}
- device.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize
+ device.awg.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize
if tempASecCfg.junkPacketMinSize != 0 {
isASecOn = true
}
- if device.aSecCfg.junkPacketCount > 0 &&
+ if device.awg.aSecCfg.junkPacketCount > 0 &&
tempASecCfg.junkPacketMaxSize == tempASecCfg.junkPacketMinSize {
tempASecCfg.junkPacketMaxSize++ // to make rand gen work
}
if tempASecCfg.junkPacketMaxSize >= MaxSegmentSize {
- device.aSecCfg.junkPacketMinSize = 0
- device.aSecCfg.junkPacketMaxSize = 1
+ device.awg.aSecCfg.junkPacketMinSize = 0
+ device.awg.aSecCfg.junkPacketMaxSize = 1
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
@@ -658,7 +661,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
)
}
} else {
- device.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize
+ device.awg.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize
}
if tempASecCfg.junkPacketMaxSize != 0 {
@@ -683,7 +686,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
)
}
} else {
- device.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize
+ device.awg.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize
}
if tempASecCfg.initPacketJunkSize != 0 {
@@ -708,7 +711,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
)
}
} else {
- device.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize
+ device.awg.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize
}
if tempASecCfg.responsePacketJunkSize != 0 {
@@ -718,8 +721,8 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
if tempASecCfg.initPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating init_packet_magic_header")
- device.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader
- MessageInitiationType = device.aSecCfg.initPacketMagicHeader
+ device.awg.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader
+ MessageInitiationType = device.awg.aSecCfg.initPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default init type")
MessageInitiationType = DefaultMessageInitiationType
@@ -728,8 +731,8 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
if tempASecCfg.responsePacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
- device.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader
- MessageResponseType = device.aSecCfg.responsePacketMagicHeader
+ device.awg.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader
+ MessageResponseType = device.awg.aSecCfg.responsePacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default response type")
MessageResponseType = DefaultMessageResponseType
@@ -738,8 +741,8 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
if tempASecCfg.underloadPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
- device.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader
- MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader
+ device.awg.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader
+ MessageCookieReplyType = device.awg.aSecCfg.underloadPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default underload type")
MessageCookieReplyType = DefaultMessageCookieReplyType
@@ -748,8 +751,8 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
if tempASecCfg.transportPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
- device.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader
- MessageTransportType = device.aSecCfg.transportPacketMagicHeader
+ device.awg.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader
+ MessageTransportType = device.awg.aSecCfg.transportPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default transport type")
MessageTransportType = DefaultMessageTransportType
@@ -785,8 +788,8 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
}
}
- newInitSize := MessageInitiationSize + device.aSecCfg.initPacketJunkSize
- newResponseSize := MessageResponseSize + device.aSecCfg.responsePacketJunkSize
+ newInitSize := MessageInitiationSize + device.awg.aSecCfg.initPacketJunkSize
+ newResponseSize := MessageResponseSize + device.awg.aSecCfg.responsePacketJunkSize
if newInitSize == newResponseSize {
if err != nil {
@@ -814,16 +817,18 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
}
msgTypeToJunkSize = map[uint32]int{
- MessageInitiationType: device.aSecCfg.initPacketJunkSize,
- MessageResponseType: device.aSecCfg.responsePacketJunkSize,
+ MessageInitiationType: device.awg.aSecCfg.initPacketJunkSize,
+ MessageResponseType: device.awg.aSecCfg.responsePacketJunkSize,
MessageCookieReplyType: 0,
MessageTransportType: 0,
}
}
- device.isASecOn.SetTo(isASecOn)
- device.junkCreator, err = NewJunkCreator(device)
- device.aSecMux.Unlock()
+ device.awg.isASecOn.SetTo(isASecOn)
+ if device.awg.isASecOn.IsSet() {
+ device.awg.junkCreator, err = NewJunkCreator(device)
+ }
+ device.awg.aSecMux.Unlock()
return err
}
diff --git a/device/junk_creator.go b/device/junk_creator.go
index 85a5bbc..2bfb197 100644
--- a/device/junk_creator.go
+++ b/device/junk_creator.go
@@ -23,12 +23,12 @@ func NewJunkCreator(d *Device) (junkCreator, error) {
// Should be called with aSecMux RLocked
func (jc *junkCreator) createJunkPackets(peer *Peer) ([][]byte, error) {
- if jc.device.aSecCfg.junkPacketCount == 0 {
+ if jc.device.awg.aSecCfg.junkPacketCount == 0 {
return nil, nil
}
- junks := make([][]byte, 0, jc.device.aSecCfg.junkPacketCount)
- for i := 0; i < jc.device.aSecCfg.junkPacketCount; i++ {
+ junks := make([][]byte, 0, jc.device.awg.aSecCfg.junkPacketCount)
+ for i := 0; i < jc.device.awg.aSecCfg.junkPacketCount; i++ {
packetSize := jc.randomPacketSize()
junk, err := jc.randomJunkWithSize(packetSize)
if err != nil {
@@ -48,9 +48,9 @@ func (jc *junkCreator) createJunkPackets(peer *Peer) ([][]byte, error) {
func (jc *junkCreator) randomPacketSize() int {
return int(
jc.cha8Rand.Uint64()%uint64(
- jc.device.aSecCfg.junkPacketMaxSize-jc.device.aSecCfg.junkPacketMinSize,
+ jc.device.awg.aSecCfg.junkPacketMaxSize-jc.device.awg.aSecCfg.junkPacketMinSize,
),
- ) + jc.device.aSecCfg.junkPacketMinSize
+ ) + jc.device.awg.aSecCfg.junkPacketMinSize
}
// Should be called with aSecMux RLocked
diff --git a/device/junk_creator_test.go b/device/junk_creator_test.go
index 6f63360..96fa6d3 100644
--- a/device/junk_creator_test.go
+++ b/device/junk_creator_test.go
@@ -91,13 +91,13 @@ func Test_junkCreator_randomPacketSize(t *testing.T) {
}
for range [30]struct{}{} {
t.Run("", func(t *testing.T) {
- if got := jc.randomPacketSize(); jc.device.aSecCfg.junkPacketMinSize > got ||
- got > jc.device.aSecCfg.junkPacketMaxSize {
+ if got := jc.randomPacketSize(); jc.device.awg.aSecCfg.junkPacketMinSize > got ||
+ got > jc.device.awg.aSecCfg.junkPacketMaxSize {
t.Errorf(
"junkCreator.randomPacketSize() = %v, not between range [%v,%v]",
got,
- jc.device.aSecCfg.junkPacketMinSize,
- jc.device.aSecCfg.junkPacketMaxSize,
+ jc.device.awg.aSecCfg.junkPacketMinSize,
+ jc.device.awg.aSecCfg.junkPacketMaxSize,
)
}
})
diff --git a/device/noise-protocol.go b/device/noise-protocol.go
index d818238..2dc0d93 100644
--- a/device/noise-protocol.go
+++ b/device/noise-protocol.go
@@ -204,12 +204,12 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(handshake.remoteStatic[:])
- device.aSecMux.RLock()
+ device.awg.aSecMux.RLock()
msg := MessageInitiation{
Type: MessageInitiationType,
Ephemeral: handshake.localEphemeral.publicKey(),
}
- device.aSecMux.RUnlock()
+ device.awg.aSecMux.RUnlock()
handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:])
@@ -263,12 +263,12 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
chainKey [blake2s.Size]byte
)
- device.aSecMux.RLock()
+ device.awg.aSecMux.RLock()
if msg.Type != MessageInitiationType {
- device.aSecMux.RUnlock()
+ device.awg.aSecMux.RUnlock()
return nil
}
- device.aSecMux.RUnlock()
+ device.awg.aSecMux.RUnlock()
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
@@ -383,9 +383,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
}
var msg MessageResponse
- device.aSecMux.RLock()
+ device.awg.aSecMux.RLock()
msg.Type = MessageResponseType
- device.aSecMux.RUnlock()
+ device.awg.aSecMux.RUnlock()
msg.Sender = handshake.localIndex
msg.Receiver = handshake.remoteIndex
@@ -435,12 +435,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
}
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
- device.aSecMux.RLock()
+ device.awg.aSecMux.RLock()
if msg.Type != MessageResponseType {
- device.aSecMux.RUnlock()
+ device.awg.aSecMux.RUnlock()
return nil
}
- device.aSecMux.RUnlock()
+ device.awg.aSecMux.RUnlock()
// lookup handshake by receiver
diff --git a/device/receive.go b/device/receive.go
index e790048..70801ef 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -129,7 +129,7 @@ func (device *Device) RoutineReceiveIncoming(
}
deathSpiral = 0
- device.aSecMux.RLock()
+ device.awg.aSecMux.RLock()
// handle each packet in the batch
for i, size := range sizes[:count] {
if size < MinMessageSize {
@@ -138,8 +138,8 @@ func (device *Device) RoutineReceiveIncoming(
// check size of packet
packet := bufsArrs[i][:size]
- if device.luaAdapter != nil {
- packet, err = device.luaAdapter.Parse(packet)
+ if device.awg.luaAdapter != nil {
+ packet, err = device.awg.luaAdapter.Parse(packet)
if err != nil {
device.log.Verbosef("Couldn't parse message; reason: %v", err)
continue
@@ -251,7 +251,7 @@ func (device *Device) RoutineReceiveIncoming(
default:
}
}
- device.aSecMux.RUnlock()
+ device.awg.aSecMux.RUnlock()
for peer, elemsContainer := range elemsByPeer {
if peer.isRunning.Load() {
peer.queue.inbound.c <- elemsContainer
@@ -310,7 +310,7 @@ func (device *Device) RoutineHandshake(id int) {
for elem := range device.queue.handshake.c {
- device.aSecMux.RLock()
+ device.awg.aSecMux.RLock()
// handle cookie fields and ratelimiting
@@ -462,7 +462,7 @@ func (device *Device) RoutineHandshake(id int) {
peer.SendKeepalive()
}
skip:
- device.aSecMux.RUnlock()
+ device.awg.aSecMux.RUnlock()
device.PutMessageBuffer(elem.buffer)
}
}
diff --git a/device/send.go b/device/send.go
index ac05313..297f19e 100644
--- a/device/send.go
+++ b/device/send.go
@@ -131,9 +131,9 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
// so only packet processed for cookie generation
var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() {
- peer.device.aSecMux.RLock()
- junks, err := peer.device.junkCreator.createJunkPackets(peer)
- peer.device.aSecMux.RUnlock()
+ peer.device.awg.aSecMux.RLock()
+ junks, err := peer.device.awg.junkCreator.createJunkPackets(peer)
+ peer.device.awg.aSecMux.RUnlock()
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
@@ -153,19 +153,19 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
}
}
- peer.device.aSecMux.RLock()
- if peer.device.aSecCfg.initPacketJunkSize != 0 {
- buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize)
+ peer.device.awg.aSecMux.RLock()
+ if peer.device.awg.aSecCfg.initPacketJunkSize != 0 {
+ buf := make([]byte, 0, peer.device.awg.aSecCfg.initPacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
- err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize)
+ err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.initPacketJunkSize)
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
- peer.device.aSecMux.RUnlock()
+ peer.device.awg.aSecMux.RUnlock()
return err
}
junkedHeader = writer.Bytes()
}
- peer.device.aSecMux.RUnlock()
+ peer.device.awg.aSecMux.RUnlock()
}
var buf [MessageInitiationSize]byte
@@ -215,19 +215,19 @@ func (peer *Peer) SendHandshakeResponse() error {
}
var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() {
- peer.device.aSecMux.RLock()
- if peer.device.aSecCfg.responsePacketJunkSize != 0 {
- buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize)
+ peer.device.awg.aSecMux.RLock()
+ if peer.device.awg.aSecCfg.responsePacketJunkSize != 0 {
+ buf := make([]byte, 0, peer.device.awg.aSecCfg.responsePacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
- err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize)
+ err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.responsePacketJunkSize)
if err != nil {
- peer.device.aSecMux.RUnlock()
+ peer.device.awg.aSecMux.RUnlock()
peer.device.log.Errorf("%v - %v", peer, err)
return err
}
junkedHeader = writer.Bytes()
}
- peer.device.aSecMux.RUnlock()
+ peer.device.awg.aSecMux.RUnlock()
}
var buf [MessageResponseSize]byte
writer := bytes.NewBuffer(buf[:0])
@@ -548,9 +548,9 @@ func calculatePaddingSize(packetSize, mtu int) int {
}
func (device *Device) codecPacket(msgType uint32, packet []byte) ([]byte, error) {
- if device.luaAdapter != nil {
+ if device.awg.luaAdapter != nil {
var err error
- packet, err = device.luaAdapter.Generate(int64(msgType),packet, device.packetCounter.Add(1))
+ packet, err = device.awg.luaAdapter.Generate(int64(msgType),packet)
if err != nil {
device.log.Errorf("%v - Failed to run codec generate: %v", device, err)
return nil, err
diff --git a/device/uapi.go b/device/uapi.go
index 777bdda..1bce39b 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -98,32 +98,32 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
}
if device.isAdvancedSecurityOn() {
- if device.aSecCfg.junkPacketCount != 0 {
- sendf("jc=%d", device.aSecCfg.junkPacketCount)
+ if device.awg.aSecCfg.junkPacketCount != 0 {
+ sendf("jc=%d", device.awg.aSecCfg.junkPacketCount)
}
- if device.aSecCfg.junkPacketMinSize != 0 {
- sendf("jmin=%d", device.aSecCfg.junkPacketMinSize)
+ if device.awg.aSecCfg.junkPacketMinSize != 0 {
+ sendf("jmin=%d", device.awg.aSecCfg.junkPacketMinSize)
}
- if device.aSecCfg.junkPacketMaxSize != 0 {
- sendf("jmax=%d", device.aSecCfg.junkPacketMaxSize)
+ if device.awg.aSecCfg.junkPacketMaxSize != 0 {
+ sendf("jmax=%d", device.awg.aSecCfg.junkPacketMaxSize)
}
- if device.aSecCfg.initPacketJunkSize != 0 {
- sendf("s1=%d", device.aSecCfg.initPacketJunkSize)
+ if device.awg.aSecCfg.initPacketJunkSize != 0 {
+ sendf("s1=%d", device.awg.aSecCfg.initPacketJunkSize)
}
- if device.aSecCfg.responsePacketJunkSize != 0 {
- sendf("s2=%d", device.aSecCfg.responsePacketJunkSize)
+ if device.awg.aSecCfg.responsePacketJunkSize != 0 {
+ sendf("s2=%d", device.awg.aSecCfg.responsePacketJunkSize)
}
- if device.aSecCfg.initPacketMagicHeader != 0 {
- sendf("h1=%d", device.aSecCfg.initPacketMagicHeader)
+ if device.awg.aSecCfg.initPacketMagicHeader != 0 {
+ sendf("h1=%d", device.awg.aSecCfg.initPacketMagicHeader)
}
- if device.aSecCfg.responsePacketMagicHeader != 0 {
- sendf("h2=%d", device.aSecCfg.responsePacketMagicHeader)
+ if device.awg.aSecCfg.responsePacketMagicHeader != 0 {
+ sendf("h2=%d", device.awg.aSecCfg.responsePacketMagicHeader)
}
- if device.aSecCfg.underloadPacketMagicHeader != 0 {
- sendf("h3=%d", device.aSecCfg.underloadPacketMagicHeader)
+ if device.awg.aSecCfg.underloadPacketMagicHeader != 0 {
+ sendf("h3=%d", device.awg.aSecCfg.underloadPacketMagicHeader)
}
- if device.aSecCfg.transportPacketMagicHeader != 0 {
- sendf("h4=%d", device.aSecCfg.transportPacketMagicHeader)
+ if device.awg.aSecCfg.transportPacketMagicHeader != 0 {
+ sendf("h4=%d", device.awg.aSecCfg.transportPacketMagicHeader)
}
}