diff --git a/device/awg/awg.go b/device/awg/awg.go index 29a4b62..888a42e 100644 --- a/device/awg/awg.go +++ b/device/awg/awg.go @@ -18,19 +18,15 @@ type Cfg struct { CookieReplyHeaderJunkSize int TransportHeaderJunkSize int - InitPacketMagicHeader MagicHeader - ResponsePacketMagicHeader MagicHeader - UnderloadPacketMagicHeader MagicHeader - TransportPacketMagicHeader MagicHeader + MagicHeaders MagicHeaders } type Protocol struct { IsOn abool.AtomicBool // TODO: revision the need of the mutex - Mux sync.RWMutex - Cfg Cfg - JunkCreator JunkCreator - MagicHeaders MagicHeaders + Mux sync.RWMutex + Cfg Cfg + JunkCreator JunkCreator HandshakeHandler SpecialHandshakeHandler } @@ -80,9 +76,9 @@ func (protocol *Protocol) createHeaderJunk(junkSize int, extraSize int) ([]byte, } func (protocol *Protocol) GetMagicHeaderMinFor(msgType uint32) (uint32, error) { - for _, limit := range protocol.MagicHeaders.headerValues { - if limit.Min <= msgType && msgType <= limit.Max { - return limit.Min, nil + for _, magicHeader := range protocol.Cfg.MagicHeaders.Values { + if magicHeader.Min <= msgType && msgType <= magicHeader.Max { + return magicHeader.Min, nil } } @@ -90,5 +86,5 @@ func (protocol *Protocol) GetMagicHeaderMinFor(msgType uint32) (uint32, error) { } func (protocol *Protocol) GetMsgType(defaultMsgType uint32) (uint32, error) { - return protocol.MagicHeaders.Get(defaultMsgType) + return protocol.Cfg.MagicHeaders.Get(defaultMsgType) } diff --git a/device/awg/magic_header.go b/device/awg/magic_header.go index df8bdf1..b8bfb2f 100644 --- a/device/awg/magic_header.go +++ b/device/awg/magic_header.go @@ -58,7 +58,7 @@ func ParseMagicHeader(key, value string) (MagicHeader, error) { } type MagicHeaders struct { - headerValues []MagicHeader + Values []MagicHeader randomGenerator RandomNumberGenerator[uint32] } @@ -81,7 +81,7 @@ func NewMagicHeaders(headerValues []MagicHeader) (MagicHeaders, error) { } } - return MagicHeaders{headerValues: headerValues, randomGenerator: NewPRNG[uint32]()}, nil + return MagicHeaders{Values: headerValues, randomGenerator: NewPRNG[uint32]()}, nil } func (mh *MagicHeaders) Get(defaultMsgType uint32) (uint32, error) { @@ -89,5 +89,5 @@ func (mh *MagicHeaders) Get(defaultMsgType uint32) (uint32, error) { return 0, fmt.Errorf("invalid msg type: %d", defaultMsgType) } - return mh.randomGenerator.RandomSizeInRange(mh.headerValues[defaultMsgType-1].Min, mh.headerValues[defaultMsgType-1].Max), nil + return mh.randomGenerator.RandomSizeInRange(mh.Values[defaultMsgType-1].Min, mh.Values[defaultMsgType-1].Max), nil } diff --git a/device/awg/magic_header_test.go b/device/awg/magic_header_test.go index d9e608c..72a823e 100644 --- a/device/awg/magic_header_test.go +++ b/device/awg/magic_header_test.go @@ -378,7 +378,7 @@ func TestNewMagicHeaders(t *testing.T) { require.Equal(t, MagicHeaders{}, result) } else { require.NoError(t, err) - require.Equal(t, tt.magicHeaders, result.headerValues) + require.Equal(t, tt.magicHeaders, result.Values) require.NotNil(t, result.randomGenerator) } }) @@ -469,7 +469,7 @@ func TestMagicHeaders_Get(t *testing.T) { t.Parallel() // Create a new instance with mock PRNG for each test testMagicHeaders := MagicHeaders{ - headerValues: headers, + Values: headers, randomGenerator: &mockPRNG{returnValue: tt.mockValue}, } diff --git a/device/device.go b/device/device.go index fe71037..6853e45 100644 --- a/device/device.go +++ b/device/device.go @@ -580,6 +580,7 @@ func (device *Device) BindClose() error { device.net.Unlock() return err } + func (device *Device) isAWG() bool { return device.version >= VersionAwg } @@ -599,7 +600,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { var errs []error - isASecOn := false + isAwgOn := false device.awg.Mux.Lock() if tempAwg.Cfg.JunkPacketCount < 0 { errs = append(errs, ipcErrorf( @@ -610,12 +611,12 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { } device.awg.Cfg.JunkPacketCount = tempAwg.Cfg.JunkPacketCount if tempAwg.Cfg.JunkPacketCount != 0 { - isASecOn = true + isAwgOn = true } device.awg.Cfg.JunkPacketMinSize = tempAwg.Cfg.JunkPacketMinSize if tempAwg.Cfg.JunkPacketMinSize != 0 { - isASecOn = true + isAwgOn = true } if device.awg.Cfg.JunkPacketCount > 0 && @@ -645,57 +646,72 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { } if tempAwg.Cfg.JunkPacketMaxSize != 0 { - isASecOn = true + isAwgOn = true } - limits := make([]awg.MagicHeader, 4) + magicHeaders := make([]awg.MagicHeader, 4) - if tempAwg.Cfg.InitPacketMagicHeader.Min > 4 { - isASecOn = true + if len(tempAwg.Cfg.MagicHeaders.Values) != 4 { + return ipcErrorf( + ipc.IpcErrorInvalid, + "magic headers should have 4 values; got: %d", + len(tempAwg.Cfg.MagicHeaders.Values), + ) + } + + + if tempAwg.Cfg.MagicHeaders.Values[0].Min > 4 { + isAwgOn = true device.log.Verbosef("UAPI: Updating init_packet_magic_header") - device.awg.Cfg.InitPacketMagicHeader = tempAwg.Cfg.InitPacketMagicHeader - limits[0] = tempAwg.Cfg.InitPacketMagicHeader - MessageInitiationType = device.awg.Cfg.InitPacketMagicHeader.Min + magicHeaders[0] = tempAwg.Cfg.MagicHeaders.Values[0] + + MessageInitiationType = magicHeaders[0].Min } else { device.log.Verbosef("UAPI: Using default init type") MessageInitiationType = DefaultMessageInitiationType - limits[0] = awg.NewMagicHeaderSameValue(DefaultMessageInitiationType) + magicHeaders[0] = awg.NewMagicHeaderSameValue(DefaultMessageInitiationType) } - if tempAwg.Cfg.ResponsePacketMagicHeader.Min > 4 { - isASecOn = true + if tempAwg.Cfg.MagicHeaders.Values[1].Min > 4 { + isAwgOn = true + device.log.Verbosef("UAPI: Updating response_packet_magic_header") - device.awg.Cfg.ResponsePacketMagicHeader = tempAwg.Cfg.ResponsePacketMagicHeader - MessageResponseType = device.awg.Cfg.ResponsePacketMagicHeader.Min - limits[1] = tempAwg.Cfg.ResponsePacketMagicHeader + magicHeaders[1] = tempAwg.Cfg.MagicHeaders.Values[1] + MessageResponseType = magicHeaders[1].Min } else { device.log.Verbosef("UAPI: Using default response type") MessageResponseType = DefaultMessageResponseType - limits[1] = awg.NewMagicHeaderSameValue(DefaultMessageResponseType) + magicHeaders[1] = awg.NewMagicHeaderSameValue(DefaultMessageResponseType) } - if tempAwg.Cfg.UnderloadPacketMagicHeader.Min > 4 { - isASecOn = true + if tempAwg.Cfg.MagicHeaders.Values[2].Min > 4 { + isAwgOn = true + device.log.Verbosef("UAPI: Updating underload_packet_magic_header") - device.awg.Cfg.UnderloadPacketMagicHeader = tempAwg.Cfg.UnderloadPacketMagicHeader - MessageCookieReplyType = device.awg.Cfg.UnderloadPacketMagicHeader.Min - limits[2] = tempAwg.Cfg.UnderloadPacketMagicHeader + magicHeaders[2] = tempAwg.Cfg.MagicHeaders.Values[2] + MessageCookieReplyType = magicHeaders[2].Min } else { device.log.Verbosef("UAPI: Using default underload type") MessageCookieReplyType = DefaultMessageCookieReplyType - limits[2] = awg.NewMagicHeaderSameValue(DefaultMessageCookieReplyType) + magicHeaders[2] = awg.NewMagicHeaderSameValue(DefaultMessageCookieReplyType) } - if tempAwg.Cfg.TransportPacketMagicHeader.Min > 4 { - isASecOn = true + if tempAwg.Cfg.MagicHeaders.Values[3].Min > 4 { + isAwgOn = true + device.log.Verbosef("UAPI: Updating transport_packet_magic_header") - device.awg.Cfg.TransportPacketMagicHeader = tempAwg.Cfg.TransportPacketMagicHeader - MessageTransportType = device.awg.Cfg.TransportPacketMagicHeader.Min - limits[3] = tempAwg.Cfg.TransportPacketMagicHeader + magicHeaders[3] = tempAwg.Cfg.MagicHeaders.Values[3] + MessageTransportType = magicHeaders[3].Min } else { device.log.Verbosef("UAPI: Using default transport type") MessageTransportType = DefaultMessageTransportType - limits[3] = awg.NewMagicHeaderSameValue(DefaultMessageTransportType) + magicHeaders[3] = awg.NewMagicHeaderSameValue(DefaultMessageTransportType) + } + + var err error + device.awg.Cfg.MagicHeaders, err = awg.NewMagicHeaders(magicHeaders) + if err != nil { + errs = append(errs, ipcErrorf(ipc.IpcErrorInvalid, "new magic headers: %w", err)) } isSameHeaderMap := map[uint32]struct{}{ @@ -705,12 +721,6 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { MessageTransportType: {}, } - var err error - device.awg.MagicHeaders, err = awg.NewMagicHeaders(limits) - if err != nil { - errs = append(errs, ipcErrorf(ipc.IpcErrorInvalid, "new magic headers: %w", err)) - } - // size will be different if same values if len(isSameHeaderMap) != 4 { errs = append(errs, ipcErrorf( @@ -739,7 +749,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { } if tempAwg.Cfg.InitHeaderJunkSize != 0 { - isASecOn = true + isAwgOn = true } newResponseSize := MessageResponseSize + tempAwg.Cfg.ResponseHeaderJunkSize @@ -757,7 +767,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { } if tempAwg.Cfg.ResponseHeaderJunkSize != 0 { - isASecOn = true + isAwgOn = true } newCookieSize := MessageCookieReplySize + tempAwg.Cfg.CookieReplyHeaderJunkSize @@ -775,7 +785,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { } if tempAwg.Cfg.CookieReplyHeaderJunkSize != 0 { - isASecOn = true + isAwgOn = true } newTransportSize := MessageTransportSize + tempAwg.Cfg.TransportHeaderJunkSize @@ -793,7 +803,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { } if tempAwg.Cfg.TransportHeaderJunkSize != 0 { - isASecOn = true + isAwgOn = true } isSameSizeMap := map[int]struct{}{ @@ -829,7 +839,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { } } - device.awg.IsOn.SetTo(isASecOn) + device.awg.IsOn.SetTo(isAwgOn) device.awg.JunkCreator = awg.NewJunkCreator(device.awg.Cfg) if tempAwg.HandshakeHandler.IsSet { diff --git a/device/uapi.go b/device/uapi.go index f4cb834..6a0cdc2 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -120,19 +120,16 @@ func (device *Device) IpcGetOperation(w io.Writer) error { if device.awg.Cfg.TransportHeaderJunkSize != 0 { sendf("s4=%d", device.awg.Cfg.TransportHeaderJunkSize) } - // TODO: - // if device.awg.ASecCfg.InitPacketMagicHeader != 0 { - // sendf("h1=%d", device.awg.ASecCfg.InitPacketMagicHeader) - // } - // if device.awg.ASecCfg.ResponsePacketMagicHeader != 0 { - // sendf("h2=%d", device.awg.ASecCfg.ResponsePacketMagicHeader) - // } - // if device.awg.ASecCfg.UnderloadPacketMagicHeader != 0 { - // sendf("h3=%d", device.awg.ASecCfg.UnderloadPacketMagicHeader) - // } - // if device.awg.ASecCfg.TransportPacketMagicHeader != 0 { - // sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader) - // } + for i, magicHeader := range device.awg.Cfg.MagicHeaders.Values { + if magicHeader.Min > 4 { + if magicHeader.Min == magicHeader.Max { + sendf("h%d=%d", i+1, magicHeader.Min) + continue + } + + sendf("h%d=%d-%d", i+1, magicHeader.Min, magicHeader.Max) + } + } specialJunkIpcFields := device.awg.HandshakeHandler.SpecialJunk.IpcGetFields() for _, field := range specialJunkIpcFields { @@ -201,6 +198,8 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { deviceConfig := true tempAwg := awg.Protocol{} + tempAwg.Cfg.MagicHeaders.Values = make([]awg.MagicHeader, 4) + scanner := bufio.NewScanner(r) for scanner.Scan() { line := scanner.Text() @@ -369,43 +368,38 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol) device.log.Verbosef("UAPI: Updating transport_packet_junk_size") tempAwg.Cfg.TransportHeaderJunkSize = transportPacketJunkSize tempAwg.Cfg.IsSet = true - case "h1": initMagicHeader, err := awg.ParseMagicHeader(key, value) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err) } - tempAwg.Cfg.InitPacketMagicHeader = initMagicHeader + tempAwg.Cfg.MagicHeaders.Values[0] = initMagicHeader tempAwg.Cfg.IsSet = true - case "h2": responseMagicHeader, err := awg.ParseMagicHeader(key, value) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err) } - tempAwg.Cfg.ResponsePacketMagicHeader = responseMagicHeader + tempAwg.Cfg.MagicHeaders.Values[1] = responseMagicHeader tempAwg.Cfg.IsSet = true - case "h3": cookieReplyMagicHeader, err := awg.ParseMagicHeader(key, value) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err) } - tempAwg.Cfg.UnderloadPacketMagicHeader = cookieReplyMagicHeader + tempAwg.Cfg.MagicHeaders.Values[2] = cookieReplyMagicHeader tempAwg.Cfg.IsSet = true - case "h4": transportMagicHeader, err := awg.ParseMagicHeader(key, value) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err) } - tempAwg.Cfg.TransportPacketMagicHeader = transportMagicHeader + tempAwg.Cfg.MagicHeaders.Values[3] = transportMagicHeader tempAwg.Cfg.IsSet = true - case "i1", "i2", "i3", "i4", "i5": if len(value) == 0 { device.log.Verbosef("UAPI: received empty %s", key)