chore: restructure code and finish impl

This commit is contained in:
Mark Puha 2025-07-10 20:25:16 +02:00
parent 6992e18755
commit 699bd240cc
5 changed files with 79 additions and 79 deletions

View file

@ -18,19 +18,15 @@ type Cfg struct {
CookieReplyHeaderJunkSize int
TransportHeaderJunkSize int
InitPacketMagicHeader MagicHeader
ResponsePacketMagicHeader MagicHeader
UnderloadPacketMagicHeader MagicHeader
TransportPacketMagicHeader MagicHeader
MagicHeaders MagicHeaders
}
type Protocol struct {
IsOn abool.AtomicBool
// TODO: revision the need of the mutex
Mux sync.RWMutex
Cfg Cfg
JunkCreator JunkCreator
MagicHeaders MagicHeaders
Mux sync.RWMutex
Cfg Cfg
JunkCreator JunkCreator
HandshakeHandler SpecialHandshakeHandler
}
@ -80,9 +76,9 @@ func (protocol *Protocol) createHeaderJunk(junkSize int, extraSize int) ([]byte,
}
func (protocol *Protocol) GetMagicHeaderMinFor(msgType uint32) (uint32, error) {
for _, limit := range protocol.MagicHeaders.headerValues {
if limit.Min <= msgType && msgType <= limit.Max {
return limit.Min, nil
for _, magicHeader := range protocol.Cfg.MagicHeaders.Values {
if magicHeader.Min <= msgType && msgType <= magicHeader.Max {
return magicHeader.Min, nil
}
}
@ -90,5 +86,5 @@ func (protocol *Protocol) GetMagicHeaderMinFor(msgType uint32) (uint32, error) {
}
func (protocol *Protocol) GetMsgType(defaultMsgType uint32) (uint32, error) {
return protocol.MagicHeaders.Get(defaultMsgType)
return protocol.Cfg.MagicHeaders.Get(defaultMsgType)
}

View file

@ -58,7 +58,7 @@ func ParseMagicHeader(key, value string) (MagicHeader, error) {
}
type MagicHeaders struct {
headerValues []MagicHeader
Values []MagicHeader
randomGenerator RandomNumberGenerator[uint32]
}
@ -81,7 +81,7 @@ func NewMagicHeaders(headerValues []MagicHeader) (MagicHeaders, error) {
}
}
return MagicHeaders{headerValues: headerValues, randomGenerator: NewPRNG[uint32]()}, nil
return MagicHeaders{Values: headerValues, randomGenerator: NewPRNG[uint32]()}, nil
}
func (mh *MagicHeaders) Get(defaultMsgType uint32) (uint32, error) {
@ -89,5 +89,5 @@ func (mh *MagicHeaders) Get(defaultMsgType uint32) (uint32, error) {
return 0, fmt.Errorf("invalid msg type: %d", defaultMsgType)
}
return mh.randomGenerator.RandomSizeInRange(mh.headerValues[defaultMsgType-1].Min, mh.headerValues[defaultMsgType-1].Max), nil
return mh.randomGenerator.RandomSizeInRange(mh.Values[defaultMsgType-1].Min, mh.Values[defaultMsgType-1].Max), nil
}

View file

@ -378,7 +378,7 @@ func TestNewMagicHeaders(t *testing.T) {
require.Equal(t, MagicHeaders{}, result)
} else {
require.NoError(t, err)
require.Equal(t, tt.magicHeaders, result.headerValues)
require.Equal(t, tt.magicHeaders, result.Values)
require.NotNil(t, result.randomGenerator)
}
})
@ -469,7 +469,7 @@ func TestMagicHeaders_Get(t *testing.T) {
t.Parallel()
// Create a new instance with mock PRNG for each test
testMagicHeaders := MagicHeaders{
headerValues: headers,
Values: headers,
randomGenerator: &mockPRNG{returnValue: tt.mockValue},
}

View file

@ -580,6 +580,7 @@ func (device *Device) BindClose() error {
device.net.Unlock()
return err
}
func (device *Device) isAWG() bool {
return device.version >= VersionAwg
}
@ -599,7 +600,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
var errs []error
isASecOn := false
isAwgOn := false
device.awg.Mux.Lock()
if tempAwg.Cfg.JunkPacketCount < 0 {
errs = append(errs, ipcErrorf(
@ -610,12 +611,12 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
}
device.awg.Cfg.JunkPacketCount = tempAwg.Cfg.JunkPacketCount
if tempAwg.Cfg.JunkPacketCount != 0 {
isASecOn = true
isAwgOn = true
}
device.awg.Cfg.JunkPacketMinSize = tempAwg.Cfg.JunkPacketMinSize
if tempAwg.Cfg.JunkPacketMinSize != 0 {
isASecOn = true
isAwgOn = true
}
if device.awg.Cfg.JunkPacketCount > 0 &&
@ -645,57 +646,72 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
}
if tempAwg.Cfg.JunkPacketMaxSize != 0 {
isASecOn = true
isAwgOn = true
}
limits := make([]awg.MagicHeader, 4)
magicHeaders := make([]awg.MagicHeader, 4)
if tempAwg.Cfg.InitPacketMagicHeader.Min > 4 {
isASecOn = true
if len(tempAwg.Cfg.MagicHeaders.Values) != 4 {
return ipcErrorf(
ipc.IpcErrorInvalid,
"magic headers should have 4 values; got: %d",
len(tempAwg.Cfg.MagicHeaders.Values),
)
}
if tempAwg.Cfg.MagicHeaders.Values[0].Min > 4 {
isAwgOn = true
device.log.Verbosef("UAPI: Updating init_packet_magic_header")
device.awg.Cfg.InitPacketMagicHeader = tempAwg.Cfg.InitPacketMagicHeader
limits[0] = tempAwg.Cfg.InitPacketMagicHeader
MessageInitiationType = device.awg.Cfg.InitPacketMagicHeader.Min
magicHeaders[0] = tempAwg.Cfg.MagicHeaders.Values[0]
MessageInitiationType = magicHeaders[0].Min
} else {
device.log.Verbosef("UAPI: Using default init type")
MessageInitiationType = DefaultMessageInitiationType
limits[0] = awg.NewMagicHeaderSameValue(DefaultMessageInitiationType)
magicHeaders[0] = awg.NewMagicHeaderSameValue(DefaultMessageInitiationType)
}
if tempAwg.Cfg.ResponsePacketMagicHeader.Min > 4 {
isASecOn = true
if tempAwg.Cfg.MagicHeaders.Values[1].Min > 4 {
isAwgOn = true
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
device.awg.Cfg.ResponsePacketMagicHeader = tempAwg.Cfg.ResponsePacketMagicHeader
MessageResponseType = device.awg.Cfg.ResponsePacketMagicHeader.Min
limits[1] = tempAwg.Cfg.ResponsePacketMagicHeader
magicHeaders[1] = tempAwg.Cfg.MagicHeaders.Values[1]
MessageResponseType = magicHeaders[1].Min
} else {
device.log.Verbosef("UAPI: Using default response type")
MessageResponseType = DefaultMessageResponseType
limits[1] = awg.NewMagicHeaderSameValue(DefaultMessageResponseType)
magicHeaders[1] = awg.NewMagicHeaderSameValue(DefaultMessageResponseType)
}
if tempAwg.Cfg.UnderloadPacketMagicHeader.Min > 4 {
isASecOn = true
if tempAwg.Cfg.MagicHeaders.Values[2].Min > 4 {
isAwgOn = true
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
device.awg.Cfg.UnderloadPacketMagicHeader = tempAwg.Cfg.UnderloadPacketMagicHeader
MessageCookieReplyType = device.awg.Cfg.UnderloadPacketMagicHeader.Min
limits[2] = tempAwg.Cfg.UnderloadPacketMagicHeader
magicHeaders[2] = tempAwg.Cfg.MagicHeaders.Values[2]
MessageCookieReplyType = magicHeaders[2].Min
} else {
device.log.Verbosef("UAPI: Using default underload type")
MessageCookieReplyType = DefaultMessageCookieReplyType
limits[2] = awg.NewMagicHeaderSameValue(DefaultMessageCookieReplyType)
magicHeaders[2] = awg.NewMagicHeaderSameValue(DefaultMessageCookieReplyType)
}
if tempAwg.Cfg.TransportPacketMagicHeader.Min > 4 {
isASecOn = true
if tempAwg.Cfg.MagicHeaders.Values[3].Min > 4 {
isAwgOn = true
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
device.awg.Cfg.TransportPacketMagicHeader = tempAwg.Cfg.TransportPacketMagicHeader
MessageTransportType = device.awg.Cfg.TransportPacketMagicHeader.Min
limits[3] = tempAwg.Cfg.TransportPacketMagicHeader
magicHeaders[3] = tempAwg.Cfg.MagicHeaders.Values[3]
MessageTransportType = magicHeaders[3].Min
} else {
device.log.Verbosef("UAPI: Using default transport type")
MessageTransportType = DefaultMessageTransportType
limits[3] = awg.NewMagicHeaderSameValue(DefaultMessageTransportType)
magicHeaders[3] = awg.NewMagicHeaderSameValue(DefaultMessageTransportType)
}
var err error
device.awg.Cfg.MagicHeaders, err = awg.NewMagicHeaders(magicHeaders)
if err != nil {
errs = append(errs, ipcErrorf(ipc.IpcErrorInvalid, "new magic headers: %w", err))
}
isSameHeaderMap := map[uint32]struct{}{
@ -705,12 +721,6 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
MessageTransportType: {},
}
var err error
device.awg.MagicHeaders, err = awg.NewMagicHeaders(limits)
if err != nil {
errs = append(errs, ipcErrorf(ipc.IpcErrorInvalid, "new magic headers: %w", err))
}
// size will be different if same values
if len(isSameHeaderMap) != 4 {
errs = append(errs, ipcErrorf(
@ -739,7 +749,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
}
if tempAwg.Cfg.InitHeaderJunkSize != 0 {
isASecOn = true
isAwgOn = true
}
newResponseSize := MessageResponseSize + tempAwg.Cfg.ResponseHeaderJunkSize
@ -757,7 +767,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
}
if tempAwg.Cfg.ResponseHeaderJunkSize != 0 {
isASecOn = true
isAwgOn = true
}
newCookieSize := MessageCookieReplySize + tempAwg.Cfg.CookieReplyHeaderJunkSize
@ -775,7 +785,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
}
if tempAwg.Cfg.CookieReplyHeaderJunkSize != 0 {
isASecOn = true
isAwgOn = true
}
newTransportSize := MessageTransportSize + tempAwg.Cfg.TransportHeaderJunkSize
@ -793,7 +803,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
}
if tempAwg.Cfg.TransportHeaderJunkSize != 0 {
isASecOn = true
isAwgOn = true
}
isSameSizeMap := map[int]struct{}{
@ -829,7 +839,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
}
}
device.awg.IsOn.SetTo(isASecOn)
device.awg.IsOn.SetTo(isAwgOn)
device.awg.JunkCreator = awg.NewJunkCreator(device.awg.Cfg)
if tempAwg.HandshakeHandler.IsSet {

View file

@ -120,19 +120,16 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
if device.awg.Cfg.TransportHeaderJunkSize != 0 {
sendf("s4=%d", device.awg.Cfg.TransportHeaderJunkSize)
}
// TODO:
// if device.awg.ASecCfg.InitPacketMagicHeader != 0 {
// sendf("h1=%d", device.awg.ASecCfg.InitPacketMagicHeader)
// }
// if device.awg.ASecCfg.ResponsePacketMagicHeader != 0 {
// sendf("h2=%d", device.awg.ASecCfg.ResponsePacketMagicHeader)
// }
// if device.awg.ASecCfg.UnderloadPacketMagicHeader != 0 {
// sendf("h3=%d", device.awg.ASecCfg.UnderloadPacketMagicHeader)
// }
// if device.awg.ASecCfg.TransportPacketMagicHeader != 0 {
// sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader)
// }
for i, magicHeader := range device.awg.Cfg.MagicHeaders.Values {
if magicHeader.Min > 4 {
if magicHeader.Min == magicHeader.Max {
sendf("h%d=%d", i+1, magicHeader.Min)
continue
}
sendf("h%d=%d-%d", i+1, magicHeader.Min, magicHeader.Max)
}
}
specialJunkIpcFields := device.awg.HandshakeHandler.SpecialJunk.IpcGetFields()
for _, field := range specialJunkIpcFields {
@ -201,6 +198,8 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
deviceConfig := true
tempAwg := awg.Protocol{}
tempAwg.Cfg.MagicHeaders.Values = make([]awg.MagicHeader, 4)
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
@ -369,43 +368,38 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
device.log.Verbosef("UAPI: Updating transport_packet_junk_size")
tempAwg.Cfg.TransportHeaderJunkSize = transportPacketJunkSize
tempAwg.Cfg.IsSet = true
case "h1":
initMagicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
}
tempAwg.Cfg.InitPacketMagicHeader = initMagicHeader
tempAwg.Cfg.MagicHeaders.Values[0] = initMagicHeader
tempAwg.Cfg.IsSet = true
case "h2":
responseMagicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
}
tempAwg.Cfg.ResponsePacketMagicHeader = responseMagicHeader
tempAwg.Cfg.MagicHeaders.Values[1] = responseMagicHeader
tempAwg.Cfg.IsSet = true
case "h3":
cookieReplyMagicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
}
tempAwg.Cfg.UnderloadPacketMagicHeader = cookieReplyMagicHeader
tempAwg.Cfg.MagicHeaders.Values[2] = cookieReplyMagicHeader
tempAwg.Cfg.IsSet = true
case "h4":
transportMagicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
}
tempAwg.Cfg.TransportPacketMagicHeader = transportMagicHeader
tempAwg.Cfg.MagicHeaders.Values[3] = transportMagicHeader
tempAwg.Cfg.IsSet = true
case "i1", "i2", "i3", "i4", "i5":
if len(value) == 0 {
device.log.Verbosef("UAPI: received empty %s", key)