chore: restructure code and finish impl

This commit is contained in:
Mark Puha 2025-07-10 20:25:16 +02:00
parent 6992e18755
commit 699bd240cc
5 changed files with 79 additions and 79 deletions

View file

@ -18,19 +18,15 @@ type Cfg struct {
CookieReplyHeaderJunkSize int CookieReplyHeaderJunkSize int
TransportHeaderJunkSize int TransportHeaderJunkSize int
InitPacketMagicHeader MagicHeader MagicHeaders MagicHeaders
ResponsePacketMagicHeader MagicHeader
UnderloadPacketMagicHeader MagicHeader
TransportPacketMagicHeader MagicHeader
} }
type Protocol struct { type Protocol struct {
IsOn abool.AtomicBool IsOn abool.AtomicBool
// TODO: revision the need of the mutex // TODO: revision the need of the mutex
Mux sync.RWMutex Mux sync.RWMutex
Cfg Cfg Cfg Cfg
JunkCreator JunkCreator JunkCreator JunkCreator
MagicHeaders MagicHeaders
HandshakeHandler SpecialHandshakeHandler HandshakeHandler SpecialHandshakeHandler
} }
@ -80,9 +76,9 @@ func (protocol *Protocol) createHeaderJunk(junkSize int, extraSize int) ([]byte,
} }
func (protocol *Protocol) GetMagicHeaderMinFor(msgType uint32) (uint32, error) { func (protocol *Protocol) GetMagicHeaderMinFor(msgType uint32) (uint32, error) {
for _, limit := range protocol.MagicHeaders.headerValues { for _, magicHeader := range protocol.Cfg.MagicHeaders.Values {
if limit.Min <= msgType && msgType <= limit.Max { if magicHeader.Min <= msgType && msgType <= magicHeader.Max {
return limit.Min, nil return magicHeader.Min, nil
} }
} }
@ -90,5 +86,5 @@ func (protocol *Protocol) GetMagicHeaderMinFor(msgType uint32) (uint32, error) {
} }
func (protocol *Protocol) GetMsgType(defaultMsgType uint32) (uint32, error) { func (protocol *Protocol) GetMsgType(defaultMsgType uint32) (uint32, error) {
return protocol.MagicHeaders.Get(defaultMsgType) return protocol.Cfg.MagicHeaders.Get(defaultMsgType)
} }

View file

@ -58,7 +58,7 @@ func ParseMagicHeader(key, value string) (MagicHeader, error) {
} }
type MagicHeaders struct { type MagicHeaders struct {
headerValues []MagicHeader Values []MagicHeader
randomGenerator RandomNumberGenerator[uint32] randomGenerator RandomNumberGenerator[uint32]
} }
@ -81,7 +81,7 @@ func NewMagicHeaders(headerValues []MagicHeader) (MagicHeaders, error) {
} }
} }
return MagicHeaders{headerValues: headerValues, randomGenerator: NewPRNG[uint32]()}, nil return MagicHeaders{Values: headerValues, randomGenerator: NewPRNG[uint32]()}, nil
} }
func (mh *MagicHeaders) Get(defaultMsgType uint32) (uint32, error) { func (mh *MagicHeaders) Get(defaultMsgType uint32) (uint32, error) {
@ -89,5 +89,5 @@ func (mh *MagicHeaders) Get(defaultMsgType uint32) (uint32, error) {
return 0, fmt.Errorf("invalid msg type: %d", defaultMsgType) return 0, fmt.Errorf("invalid msg type: %d", defaultMsgType)
} }
return mh.randomGenerator.RandomSizeInRange(mh.headerValues[defaultMsgType-1].Min, mh.headerValues[defaultMsgType-1].Max), nil return mh.randomGenerator.RandomSizeInRange(mh.Values[defaultMsgType-1].Min, mh.Values[defaultMsgType-1].Max), nil
} }

View file

@ -378,7 +378,7 @@ func TestNewMagicHeaders(t *testing.T) {
require.Equal(t, MagicHeaders{}, result) require.Equal(t, MagicHeaders{}, result)
} else { } else {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, tt.magicHeaders, result.headerValues) require.Equal(t, tt.magicHeaders, result.Values)
require.NotNil(t, result.randomGenerator) require.NotNil(t, result.randomGenerator)
} }
}) })
@ -469,7 +469,7 @@ func TestMagicHeaders_Get(t *testing.T) {
t.Parallel() t.Parallel()
// Create a new instance with mock PRNG for each test // Create a new instance with mock PRNG for each test
testMagicHeaders := MagicHeaders{ testMagicHeaders := MagicHeaders{
headerValues: headers, Values: headers,
randomGenerator: &mockPRNG{returnValue: tt.mockValue}, randomGenerator: &mockPRNG{returnValue: tt.mockValue},
} }

View file

@ -580,6 +580,7 @@ func (device *Device) BindClose() error {
device.net.Unlock() device.net.Unlock()
return err return err
} }
func (device *Device) isAWG() bool { func (device *Device) isAWG() bool {
return device.version >= VersionAwg return device.version >= VersionAwg
} }
@ -599,7 +600,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
var errs []error var errs []error
isASecOn := false isAwgOn := false
device.awg.Mux.Lock() device.awg.Mux.Lock()
if tempAwg.Cfg.JunkPacketCount < 0 { if tempAwg.Cfg.JunkPacketCount < 0 {
errs = append(errs, ipcErrorf( errs = append(errs, ipcErrorf(
@ -610,12 +611,12 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
} }
device.awg.Cfg.JunkPacketCount = tempAwg.Cfg.JunkPacketCount device.awg.Cfg.JunkPacketCount = tempAwg.Cfg.JunkPacketCount
if tempAwg.Cfg.JunkPacketCount != 0 { if tempAwg.Cfg.JunkPacketCount != 0 {
isASecOn = true isAwgOn = true
} }
device.awg.Cfg.JunkPacketMinSize = tempAwg.Cfg.JunkPacketMinSize device.awg.Cfg.JunkPacketMinSize = tempAwg.Cfg.JunkPacketMinSize
if tempAwg.Cfg.JunkPacketMinSize != 0 { if tempAwg.Cfg.JunkPacketMinSize != 0 {
isASecOn = true isAwgOn = true
} }
if device.awg.Cfg.JunkPacketCount > 0 && if device.awg.Cfg.JunkPacketCount > 0 &&
@ -645,57 +646,72 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
} }
if tempAwg.Cfg.JunkPacketMaxSize != 0 { if tempAwg.Cfg.JunkPacketMaxSize != 0 {
isASecOn = true isAwgOn = true
} }
limits := make([]awg.MagicHeader, 4) magicHeaders := make([]awg.MagicHeader, 4)
if tempAwg.Cfg.InitPacketMagicHeader.Min > 4 { if len(tempAwg.Cfg.MagicHeaders.Values) != 4 {
isASecOn = true return ipcErrorf(
ipc.IpcErrorInvalid,
"magic headers should have 4 values; got: %d",
len(tempAwg.Cfg.MagicHeaders.Values),
)
}
if tempAwg.Cfg.MagicHeaders.Values[0].Min > 4 {
isAwgOn = true
device.log.Verbosef("UAPI: Updating init_packet_magic_header") device.log.Verbosef("UAPI: Updating init_packet_magic_header")
device.awg.Cfg.InitPacketMagicHeader = tempAwg.Cfg.InitPacketMagicHeader magicHeaders[0] = tempAwg.Cfg.MagicHeaders.Values[0]
limits[0] = tempAwg.Cfg.InitPacketMagicHeader
MessageInitiationType = device.awg.Cfg.InitPacketMagicHeader.Min MessageInitiationType = magicHeaders[0].Min
} else { } else {
device.log.Verbosef("UAPI: Using default init type") device.log.Verbosef("UAPI: Using default init type")
MessageInitiationType = DefaultMessageInitiationType MessageInitiationType = DefaultMessageInitiationType
limits[0] = awg.NewMagicHeaderSameValue(DefaultMessageInitiationType) magicHeaders[0] = awg.NewMagicHeaderSameValue(DefaultMessageInitiationType)
} }
if tempAwg.Cfg.ResponsePacketMagicHeader.Min > 4 { if tempAwg.Cfg.MagicHeaders.Values[1].Min > 4 {
isASecOn = true isAwgOn = true
device.log.Verbosef("UAPI: Updating response_packet_magic_header") device.log.Verbosef("UAPI: Updating response_packet_magic_header")
device.awg.Cfg.ResponsePacketMagicHeader = tempAwg.Cfg.ResponsePacketMagicHeader magicHeaders[1] = tempAwg.Cfg.MagicHeaders.Values[1]
MessageResponseType = device.awg.Cfg.ResponsePacketMagicHeader.Min MessageResponseType = magicHeaders[1].Min
limits[1] = tempAwg.Cfg.ResponsePacketMagicHeader
} else { } else {
device.log.Verbosef("UAPI: Using default response type") device.log.Verbosef("UAPI: Using default response type")
MessageResponseType = DefaultMessageResponseType MessageResponseType = DefaultMessageResponseType
limits[1] = awg.NewMagicHeaderSameValue(DefaultMessageResponseType) magicHeaders[1] = awg.NewMagicHeaderSameValue(DefaultMessageResponseType)
} }
if tempAwg.Cfg.UnderloadPacketMagicHeader.Min > 4 { if tempAwg.Cfg.MagicHeaders.Values[2].Min > 4 {
isASecOn = true isAwgOn = true
device.log.Verbosef("UAPI: Updating underload_packet_magic_header") device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
device.awg.Cfg.UnderloadPacketMagicHeader = tempAwg.Cfg.UnderloadPacketMagicHeader magicHeaders[2] = tempAwg.Cfg.MagicHeaders.Values[2]
MessageCookieReplyType = device.awg.Cfg.UnderloadPacketMagicHeader.Min MessageCookieReplyType = magicHeaders[2].Min
limits[2] = tempAwg.Cfg.UnderloadPacketMagicHeader
} else { } else {
device.log.Verbosef("UAPI: Using default underload type") device.log.Verbosef("UAPI: Using default underload type")
MessageCookieReplyType = DefaultMessageCookieReplyType MessageCookieReplyType = DefaultMessageCookieReplyType
limits[2] = awg.NewMagicHeaderSameValue(DefaultMessageCookieReplyType) magicHeaders[2] = awg.NewMagicHeaderSameValue(DefaultMessageCookieReplyType)
} }
if tempAwg.Cfg.TransportPacketMagicHeader.Min > 4 { if tempAwg.Cfg.MagicHeaders.Values[3].Min > 4 {
isASecOn = true isAwgOn = true
device.log.Verbosef("UAPI: Updating transport_packet_magic_header") device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
device.awg.Cfg.TransportPacketMagicHeader = tempAwg.Cfg.TransportPacketMagicHeader magicHeaders[3] = tempAwg.Cfg.MagicHeaders.Values[3]
MessageTransportType = device.awg.Cfg.TransportPacketMagicHeader.Min MessageTransportType = magicHeaders[3].Min
limits[3] = tempAwg.Cfg.TransportPacketMagicHeader
} else { } else {
device.log.Verbosef("UAPI: Using default transport type") device.log.Verbosef("UAPI: Using default transport type")
MessageTransportType = DefaultMessageTransportType MessageTransportType = DefaultMessageTransportType
limits[3] = awg.NewMagicHeaderSameValue(DefaultMessageTransportType) magicHeaders[3] = awg.NewMagicHeaderSameValue(DefaultMessageTransportType)
}
var err error
device.awg.Cfg.MagicHeaders, err = awg.NewMagicHeaders(magicHeaders)
if err != nil {
errs = append(errs, ipcErrorf(ipc.IpcErrorInvalid, "new magic headers: %w", err))
} }
isSameHeaderMap := map[uint32]struct{}{ isSameHeaderMap := map[uint32]struct{}{
@ -705,12 +721,6 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
MessageTransportType: {}, MessageTransportType: {},
} }
var err error
device.awg.MagicHeaders, err = awg.NewMagicHeaders(limits)
if err != nil {
errs = append(errs, ipcErrorf(ipc.IpcErrorInvalid, "new magic headers: %w", err))
}
// size will be different if same values // size will be different if same values
if len(isSameHeaderMap) != 4 { if len(isSameHeaderMap) != 4 {
errs = append(errs, ipcErrorf( errs = append(errs, ipcErrorf(
@ -739,7 +749,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
} }
if tempAwg.Cfg.InitHeaderJunkSize != 0 { if tempAwg.Cfg.InitHeaderJunkSize != 0 {
isASecOn = true isAwgOn = true
} }
newResponseSize := MessageResponseSize + tempAwg.Cfg.ResponseHeaderJunkSize newResponseSize := MessageResponseSize + tempAwg.Cfg.ResponseHeaderJunkSize
@ -757,7 +767,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
} }
if tempAwg.Cfg.ResponseHeaderJunkSize != 0 { if tempAwg.Cfg.ResponseHeaderJunkSize != 0 {
isASecOn = true isAwgOn = true
} }
newCookieSize := MessageCookieReplySize + tempAwg.Cfg.CookieReplyHeaderJunkSize newCookieSize := MessageCookieReplySize + tempAwg.Cfg.CookieReplyHeaderJunkSize
@ -775,7 +785,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
} }
if tempAwg.Cfg.CookieReplyHeaderJunkSize != 0 { if tempAwg.Cfg.CookieReplyHeaderJunkSize != 0 {
isASecOn = true isAwgOn = true
} }
newTransportSize := MessageTransportSize + tempAwg.Cfg.TransportHeaderJunkSize newTransportSize := MessageTransportSize + tempAwg.Cfg.TransportHeaderJunkSize
@ -793,7 +803,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
} }
if tempAwg.Cfg.TransportHeaderJunkSize != 0 { if tempAwg.Cfg.TransportHeaderJunkSize != 0 {
isASecOn = true isAwgOn = true
} }
isSameSizeMap := map[int]struct{}{ isSameSizeMap := map[int]struct{}{
@ -829,7 +839,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
} }
} }
device.awg.IsOn.SetTo(isASecOn) device.awg.IsOn.SetTo(isAwgOn)
device.awg.JunkCreator = awg.NewJunkCreator(device.awg.Cfg) device.awg.JunkCreator = awg.NewJunkCreator(device.awg.Cfg)
if tempAwg.HandshakeHandler.IsSet { if tempAwg.HandshakeHandler.IsSet {

View file

@ -120,19 +120,16 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
if device.awg.Cfg.TransportHeaderJunkSize != 0 { if device.awg.Cfg.TransportHeaderJunkSize != 0 {
sendf("s4=%d", device.awg.Cfg.TransportHeaderJunkSize) sendf("s4=%d", device.awg.Cfg.TransportHeaderJunkSize)
} }
// TODO: for i, magicHeader := range device.awg.Cfg.MagicHeaders.Values {
// if device.awg.ASecCfg.InitPacketMagicHeader != 0 { if magicHeader.Min > 4 {
// sendf("h1=%d", device.awg.ASecCfg.InitPacketMagicHeader) if magicHeader.Min == magicHeader.Max {
// } sendf("h%d=%d", i+1, magicHeader.Min)
// if device.awg.ASecCfg.ResponsePacketMagicHeader != 0 { continue
// sendf("h2=%d", device.awg.ASecCfg.ResponsePacketMagicHeader) }
// }
// if device.awg.ASecCfg.UnderloadPacketMagicHeader != 0 { sendf("h%d=%d-%d", i+1, magicHeader.Min, magicHeader.Max)
// sendf("h3=%d", device.awg.ASecCfg.UnderloadPacketMagicHeader) }
// } }
// if device.awg.ASecCfg.TransportPacketMagicHeader != 0 {
// sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader)
// }
specialJunkIpcFields := device.awg.HandshakeHandler.SpecialJunk.IpcGetFields() specialJunkIpcFields := device.awg.HandshakeHandler.SpecialJunk.IpcGetFields()
for _, field := range specialJunkIpcFields { for _, field := range specialJunkIpcFields {
@ -201,6 +198,8 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
deviceConfig := true deviceConfig := true
tempAwg := awg.Protocol{} tempAwg := awg.Protocol{}
tempAwg.Cfg.MagicHeaders.Values = make([]awg.MagicHeader, 4)
scanner := bufio.NewScanner(r) scanner := bufio.NewScanner(r)
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
@ -369,43 +368,38 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
device.log.Verbosef("UAPI: Updating transport_packet_junk_size") device.log.Verbosef("UAPI: Updating transport_packet_junk_size")
tempAwg.Cfg.TransportHeaderJunkSize = transportPacketJunkSize tempAwg.Cfg.TransportHeaderJunkSize = transportPacketJunkSize
tempAwg.Cfg.IsSet = true tempAwg.Cfg.IsSet = true
case "h1": case "h1":
initMagicHeader, err := awg.ParseMagicHeader(key, value) initMagicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
} }
tempAwg.Cfg.InitPacketMagicHeader = initMagicHeader tempAwg.Cfg.MagicHeaders.Values[0] = initMagicHeader
tempAwg.Cfg.IsSet = true tempAwg.Cfg.IsSet = true
case "h2": case "h2":
responseMagicHeader, err := awg.ParseMagicHeader(key, value) responseMagicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
} }
tempAwg.Cfg.ResponsePacketMagicHeader = responseMagicHeader tempAwg.Cfg.MagicHeaders.Values[1] = responseMagicHeader
tempAwg.Cfg.IsSet = true tempAwg.Cfg.IsSet = true
case "h3": case "h3":
cookieReplyMagicHeader, err := awg.ParseMagicHeader(key, value) cookieReplyMagicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
} }
tempAwg.Cfg.UnderloadPacketMagicHeader = cookieReplyMagicHeader tempAwg.Cfg.MagicHeaders.Values[2] = cookieReplyMagicHeader
tempAwg.Cfg.IsSet = true tempAwg.Cfg.IsSet = true
case "h4": case "h4":
transportMagicHeader, err := awg.ParseMagicHeader(key, value) transportMagicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
} }
tempAwg.Cfg.TransportPacketMagicHeader = transportMagicHeader tempAwg.Cfg.MagicHeaders.Values[3] = transportMagicHeader
tempAwg.Cfg.IsSet = true tempAwg.Cfg.IsSet = true
case "i1", "i2", "i3", "i4", "i5": case "i1", "i2", "i3", "i4", "i5":
if len(value) == 0 { if len(value) == 0 {
device.log.Verbosef("UAPI: received empty %s", key) device.log.Verbosef("UAPI: received empty %s", key)