From 699bd240cc44d1a106ad3d2591df227c8fa0a8cd Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Thu, 10 Jul 2025 20:25:16 +0200
Subject: [PATCH] chore: restructure code and finish impl
---
device/awg/awg.go | 20 +++-----
device/awg/magic_header.go | 6 +--
device/awg/magic_header_test.go | 4 +-
device/device.go | 90 ++++++++++++++++++---------------
device/uapi.go | 38 ++++++--------
5 files changed, 79 insertions(+), 79 deletions(-)
diff --git a/device/awg/awg.go b/device/awg/awg.go
index 29a4b62..888a42e 100644
--- a/device/awg/awg.go
+++ b/device/awg/awg.go
@@ -18,19 +18,15 @@ type Cfg struct {
CookieReplyHeaderJunkSize int
TransportHeaderJunkSize int
- InitPacketMagicHeader MagicHeader
- ResponsePacketMagicHeader MagicHeader
- UnderloadPacketMagicHeader MagicHeader
- TransportPacketMagicHeader MagicHeader
+ MagicHeaders MagicHeaders
}
type Protocol struct {
IsOn abool.AtomicBool
// TODO: revision the need of the mutex
- Mux sync.RWMutex
- Cfg Cfg
- JunkCreator JunkCreator
- MagicHeaders MagicHeaders
+ Mux sync.RWMutex
+ Cfg Cfg
+ JunkCreator JunkCreator
HandshakeHandler SpecialHandshakeHandler
}
@@ -80,9 +76,9 @@ func (protocol *Protocol) createHeaderJunk(junkSize int, extraSize int) ([]byte,
}
func (protocol *Protocol) GetMagicHeaderMinFor(msgType uint32) (uint32, error) {
- for _, limit := range protocol.MagicHeaders.headerValues {
- if limit.Min <= msgType && msgType <= limit.Max {
- return limit.Min, nil
+ for _, magicHeader := range protocol.Cfg.MagicHeaders.Values {
+ if magicHeader.Min <= msgType && msgType <= magicHeader.Max {
+ return magicHeader.Min, nil
}
}
@@ -90,5 +86,5 @@ func (protocol *Protocol) GetMagicHeaderMinFor(msgType uint32) (uint32, error) {
}
func (protocol *Protocol) GetMsgType(defaultMsgType uint32) (uint32, error) {
- return protocol.MagicHeaders.Get(defaultMsgType)
+ return protocol.Cfg.MagicHeaders.Get(defaultMsgType)
}
diff --git a/device/awg/magic_header.go b/device/awg/magic_header.go
index df8bdf1..b8bfb2f 100644
--- a/device/awg/magic_header.go
+++ b/device/awg/magic_header.go
@@ -58,7 +58,7 @@ func ParseMagicHeader(key, value string) (MagicHeader, error) {
}
type MagicHeaders struct {
- headerValues []MagicHeader
+ Values []MagicHeader
randomGenerator RandomNumberGenerator[uint32]
}
@@ -81,7 +81,7 @@ func NewMagicHeaders(headerValues []MagicHeader) (MagicHeaders, error) {
}
}
- return MagicHeaders{headerValues: headerValues, randomGenerator: NewPRNG[uint32]()}, nil
+ return MagicHeaders{Values: headerValues, randomGenerator: NewPRNG[uint32]()}, nil
}
func (mh *MagicHeaders) Get(defaultMsgType uint32) (uint32, error) {
@@ -89,5 +89,5 @@ func (mh *MagicHeaders) Get(defaultMsgType uint32) (uint32, error) {
return 0, fmt.Errorf("invalid msg type: %d", defaultMsgType)
}
- return mh.randomGenerator.RandomSizeInRange(mh.headerValues[defaultMsgType-1].Min, mh.headerValues[defaultMsgType-1].Max), nil
+ return mh.randomGenerator.RandomSizeInRange(mh.Values[defaultMsgType-1].Min, mh.Values[defaultMsgType-1].Max), nil
}
diff --git a/device/awg/magic_header_test.go b/device/awg/magic_header_test.go
index d9e608c..72a823e 100644
--- a/device/awg/magic_header_test.go
+++ b/device/awg/magic_header_test.go
@@ -378,7 +378,7 @@ func TestNewMagicHeaders(t *testing.T) {
require.Equal(t, MagicHeaders{}, result)
} else {
require.NoError(t, err)
- require.Equal(t, tt.magicHeaders, result.headerValues)
+ require.Equal(t, tt.magicHeaders, result.Values)
require.NotNil(t, result.randomGenerator)
}
})
@@ -469,7 +469,7 @@ func TestMagicHeaders_Get(t *testing.T) {
t.Parallel()
// Create a new instance with mock PRNG for each test
testMagicHeaders := MagicHeaders{
- headerValues: headers,
+ Values: headers,
randomGenerator: &mockPRNG{returnValue: tt.mockValue},
}
diff --git a/device/device.go b/device/device.go
index fe71037..6853e45 100644
--- a/device/device.go
+++ b/device/device.go
@@ -580,6 +580,7 @@ func (device *Device) BindClose() error {
device.net.Unlock()
return err
}
+
func (device *Device) isAWG() bool {
return device.version >= VersionAwg
}
@@ -599,7 +600,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
var errs []error
- isASecOn := false
+ isAwgOn := false
device.awg.Mux.Lock()
if tempAwg.Cfg.JunkPacketCount < 0 {
errs = append(errs, ipcErrorf(
@@ -610,12 +611,12 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
}
device.awg.Cfg.JunkPacketCount = tempAwg.Cfg.JunkPacketCount
if tempAwg.Cfg.JunkPacketCount != 0 {
- isASecOn = true
+ isAwgOn = true
}
device.awg.Cfg.JunkPacketMinSize = tempAwg.Cfg.JunkPacketMinSize
if tempAwg.Cfg.JunkPacketMinSize != 0 {
- isASecOn = true
+ isAwgOn = true
}
if device.awg.Cfg.JunkPacketCount > 0 &&
@@ -645,57 +646,72 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
}
if tempAwg.Cfg.JunkPacketMaxSize != 0 {
- isASecOn = true
+ isAwgOn = true
}
- limits := make([]awg.MagicHeader, 4)
+ magicHeaders := make([]awg.MagicHeader, 4)
- if tempAwg.Cfg.InitPacketMagicHeader.Min > 4 {
- isASecOn = true
+ if len(tempAwg.Cfg.MagicHeaders.Values) != 4 {
+ return ipcErrorf(
+ ipc.IpcErrorInvalid,
+ "magic headers should have 4 values; got: %d",
+ len(tempAwg.Cfg.MagicHeaders.Values),
+ )
+ }
+
+
+ if tempAwg.Cfg.MagicHeaders.Values[0].Min > 4 {
+ isAwgOn = true
device.log.Verbosef("UAPI: Updating init_packet_magic_header")
- device.awg.Cfg.InitPacketMagicHeader = tempAwg.Cfg.InitPacketMagicHeader
- limits[0] = tempAwg.Cfg.InitPacketMagicHeader
- MessageInitiationType = device.awg.Cfg.InitPacketMagicHeader.Min
+ magicHeaders[0] = tempAwg.Cfg.MagicHeaders.Values[0]
+
+ MessageInitiationType = magicHeaders[0].Min
} else {
device.log.Verbosef("UAPI: Using default init type")
MessageInitiationType = DefaultMessageInitiationType
- limits[0] = awg.NewMagicHeaderSameValue(DefaultMessageInitiationType)
+ magicHeaders[0] = awg.NewMagicHeaderSameValue(DefaultMessageInitiationType)
}
- if tempAwg.Cfg.ResponsePacketMagicHeader.Min > 4 {
- isASecOn = true
+ if tempAwg.Cfg.MagicHeaders.Values[1].Min > 4 {
+ isAwgOn = true
+
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
- device.awg.Cfg.ResponsePacketMagicHeader = tempAwg.Cfg.ResponsePacketMagicHeader
- MessageResponseType = device.awg.Cfg.ResponsePacketMagicHeader.Min
- limits[1] = tempAwg.Cfg.ResponsePacketMagicHeader
+ magicHeaders[1] = tempAwg.Cfg.MagicHeaders.Values[1]
+ MessageResponseType = magicHeaders[1].Min
} else {
device.log.Verbosef("UAPI: Using default response type")
MessageResponseType = DefaultMessageResponseType
- limits[1] = awg.NewMagicHeaderSameValue(DefaultMessageResponseType)
+ magicHeaders[1] = awg.NewMagicHeaderSameValue(DefaultMessageResponseType)
}
- if tempAwg.Cfg.UnderloadPacketMagicHeader.Min > 4 {
- isASecOn = true
+ if tempAwg.Cfg.MagicHeaders.Values[2].Min > 4 {
+ isAwgOn = true
+
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
- device.awg.Cfg.UnderloadPacketMagicHeader = tempAwg.Cfg.UnderloadPacketMagicHeader
- MessageCookieReplyType = device.awg.Cfg.UnderloadPacketMagicHeader.Min
- limits[2] = tempAwg.Cfg.UnderloadPacketMagicHeader
+ magicHeaders[2] = tempAwg.Cfg.MagicHeaders.Values[2]
+ MessageCookieReplyType = magicHeaders[2].Min
} else {
device.log.Verbosef("UAPI: Using default underload type")
MessageCookieReplyType = DefaultMessageCookieReplyType
- limits[2] = awg.NewMagicHeaderSameValue(DefaultMessageCookieReplyType)
+ magicHeaders[2] = awg.NewMagicHeaderSameValue(DefaultMessageCookieReplyType)
}
- if tempAwg.Cfg.TransportPacketMagicHeader.Min > 4 {
- isASecOn = true
+ if tempAwg.Cfg.MagicHeaders.Values[3].Min > 4 {
+ isAwgOn = true
+
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
- device.awg.Cfg.TransportPacketMagicHeader = tempAwg.Cfg.TransportPacketMagicHeader
- MessageTransportType = device.awg.Cfg.TransportPacketMagicHeader.Min
- limits[3] = tempAwg.Cfg.TransportPacketMagicHeader
+ magicHeaders[3] = tempAwg.Cfg.MagicHeaders.Values[3]
+ MessageTransportType = magicHeaders[3].Min
} else {
device.log.Verbosef("UAPI: Using default transport type")
MessageTransportType = DefaultMessageTransportType
- limits[3] = awg.NewMagicHeaderSameValue(DefaultMessageTransportType)
+ magicHeaders[3] = awg.NewMagicHeaderSameValue(DefaultMessageTransportType)
+ }
+
+ var err error
+ device.awg.Cfg.MagicHeaders, err = awg.NewMagicHeaders(magicHeaders)
+ if err != nil {
+ errs = append(errs, ipcErrorf(ipc.IpcErrorInvalid, "new magic headers: %w", err))
}
isSameHeaderMap := map[uint32]struct{}{
@@ -705,12 +721,6 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
MessageTransportType: {},
}
- var err error
- device.awg.MagicHeaders, err = awg.NewMagicHeaders(limits)
- if err != nil {
- errs = append(errs, ipcErrorf(ipc.IpcErrorInvalid, "new magic headers: %w", err))
- }
-
// size will be different if same values
if len(isSameHeaderMap) != 4 {
errs = append(errs, ipcErrorf(
@@ -739,7 +749,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
}
if tempAwg.Cfg.InitHeaderJunkSize != 0 {
- isASecOn = true
+ isAwgOn = true
}
newResponseSize := MessageResponseSize + tempAwg.Cfg.ResponseHeaderJunkSize
@@ -757,7 +767,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
}
if tempAwg.Cfg.ResponseHeaderJunkSize != 0 {
- isASecOn = true
+ isAwgOn = true
}
newCookieSize := MessageCookieReplySize + tempAwg.Cfg.CookieReplyHeaderJunkSize
@@ -775,7 +785,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
}
if tempAwg.Cfg.CookieReplyHeaderJunkSize != 0 {
- isASecOn = true
+ isAwgOn = true
}
newTransportSize := MessageTransportSize + tempAwg.Cfg.TransportHeaderJunkSize
@@ -793,7 +803,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
}
if tempAwg.Cfg.TransportHeaderJunkSize != 0 {
- isASecOn = true
+ isAwgOn = true
}
isSameSizeMap := map[int]struct{}{
@@ -829,7 +839,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
}
}
- device.awg.IsOn.SetTo(isASecOn)
+ device.awg.IsOn.SetTo(isAwgOn)
device.awg.JunkCreator = awg.NewJunkCreator(device.awg.Cfg)
if tempAwg.HandshakeHandler.IsSet {
diff --git a/device/uapi.go b/device/uapi.go
index f4cb834..6a0cdc2 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -120,19 +120,16 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
if device.awg.Cfg.TransportHeaderJunkSize != 0 {
sendf("s4=%d", device.awg.Cfg.TransportHeaderJunkSize)
}
- // TODO:
- // if device.awg.ASecCfg.InitPacketMagicHeader != 0 {
- // sendf("h1=%d", device.awg.ASecCfg.InitPacketMagicHeader)
- // }
- // if device.awg.ASecCfg.ResponsePacketMagicHeader != 0 {
- // sendf("h2=%d", device.awg.ASecCfg.ResponsePacketMagicHeader)
- // }
- // if device.awg.ASecCfg.UnderloadPacketMagicHeader != 0 {
- // sendf("h3=%d", device.awg.ASecCfg.UnderloadPacketMagicHeader)
- // }
- // if device.awg.ASecCfg.TransportPacketMagicHeader != 0 {
- // sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader)
- // }
+ for i, magicHeader := range device.awg.Cfg.MagicHeaders.Values {
+ if magicHeader.Min > 4 {
+ if magicHeader.Min == magicHeader.Max {
+ sendf("h%d=%d", i+1, magicHeader.Min)
+ continue
+ }
+
+ sendf("h%d=%d-%d", i+1, magicHeader.Min, magicHeader.Max)
+ }
+ }
specialJunkIpcFields := device.awg.HandshakeHandler.SpecialJunk.IpcGetFields()
for _, field := range specialJunkIpcFields {
@@ -201,6 +198,8 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
deviceConfig := true
tempAwg := awg.Protocol{}
+ tempAwg.Cfg.MagicHeaders.Values = make([]awg.MagicHeader, 4)
+
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
@@ -369,43 +368,38 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
device.log.Verbosef("UAPI: Updating transport_packet_junk_size")
tempAwg.Cfg.TransportHeaderJunkSize = transportPacketJunkSize
tempAwg.Cfg.IsSet = true
-
case "h1":
initMagicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
}
- tempAwg.Cfg.InitPacketMagicHeader = initMagicHeader
+ tempAwg.Cfg.MagicHeaders.Values[0] = initMagicHeader
tempAwg.Cfg.IsSet = true
-
case "h2":
responseMagicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
}
- tempAwg.Cfg.ResponsePacketMagicHeader = responseMagicHeader
+ tempAwg.Cfg.MagicHeaders.Values[1] = responseMagicHeader
tempAwg.Cfg.IsSet = true
-
case "h3":
cookieReplyMagicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
}
- tempAwg.Cfg.UnderloadPacketMagicHeader = cookieReplyMagicHeader
+ tempAwg.Cfg.MagicHeaders.Values[2] = cookieReplyMagicHeader
tempAwg.Cfg.IsSet = true
-
case "h4":
transportMagicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
}
- tempAwg.Cfg.TransportPacketMagicHeader = transportMagicHeader
+ tempAwg.Cfg.MagicHeaders.Values[3] = transportMagicHeader
tempAwg.Cfg.IsSet = true
-
case "i1", "i2", "i3", "i4", "i5":
if len(value) == 0 {
device.log.Verbosef("UAPI: received empty %s", key)