mirror of
https://github.com/amnezia-vpn/amneziawg-go.git
synced 2025-07-30 08:32:50 +02:00
chore: restructure code and finish impl
This commit is contained in:
parent
6992e18755
commit
699bd240cc
5 changed files with 79 additions and 79 deletions
|
@ -18,19 +18,15 @@ type Cfg struct {
|
|||
CookieReplyHeaderJunkSize int
|
||||
TransportHeaderJunkSize int
|
||||
|
||||
InitPacketMagicHeader MagicHeader
|
||||
ResponsePacketMagicHeader MagicHeader
|
||||
UnderloadPacketMagicHeader MagicHeader
|
||||
TransportPacketMagicHeader MagicHeader
|
||||
MagicHeaders MagicHeaders
|
||||
}
|
||||
|
||||
type Protocol struct {
|
||||
IsOn abool.AtomicBool
|
||||
// TODO: revision the need of the mutex
|
||||
Mux sync.RWMutex
|
||||
Cfg Cfg
|
||||
JunkCreator JunkCreator
|
||||
MagicHeaders MagicHeaders
|
||||
Mux sync.RWMutex
|
||||
Cfg Cfg
|
||||
JunkCreator JunkCreator
|
||||
|
||||
HandshakeHandler SpecialHandshakeHandler
|
||||
}
|
||||
|
@ -80,9 +76,9 @@ func (protocol *Protocol) createHeaderJunk(junkSize int, extraSize int) ([]byte,
|
|||
}
|
||||
|
||||
func (protocol *Protocol) GetMagicHeaderMinFor(msgType uint32) (uint32, error) {
|
||||
for _, limit := range protocol.MagicHeaders.headerValues {
|
||||
if limit.Min <= msgType && msgType <= limit.Max {
|
||||
return limit.Min, nil
|
||||
for _, magicHeader := range protocol.Cfg.MagicHeaders.Values {
|
||||
if magicHeader.Min <= msgType && msgType <= magicHeader.Max {
|
||||
return magicHeader.Min, nil
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -90,5 +86,5 @@ func (protocol *Protocol) GetMagicHeaderMinFor(msgType uint32) (uint32, error) {
|
|||
}
|
||||
|
||||
func (protocol *Protocol) GetMsgType(defaultMsgType uint32) (uint32, error) {
|
||||
return protocol.MagicHeaders.Get(defaultMsgType)
|
||||
return protocol.Cfg.MagicHeaders.Get(defaultMsgType)
|
||||
}
|
||||
|
|
|
@ -58,7 +58,7 @@ func ParseMagicHeader(key, value string) (MagicHeader, error) {
|
|||
}
|
||||
|
||||
type MagicHeaders struct {
|
||||
headerValues []MagicHeader
|
||||
Values []MagicHeader
|
||||
randomGenerator RandomNumberGenerator[uint32]
|
||||
}
|
||||
|
||||
|
@ -81,7 +81,7 @@ func NewMagicHeaders(headerValues []MagicHeader) (MagicHeaders, error) {
|
|||
}
|
||||
}
|
||||
|
||||
return MagicHeaders{headerValues: headerValues, randomGenerator: NewPRNG[uint32]()}, nil
|
||||
return MagicHeaders{Values: headerValues, randomGenerator: NewPRNG[uint32]()}, nil
|
||||
}
|
||||
|
||||
func (mh *MagicHeaders) Get(defaultMsgType uint32) (uint32, error) {
|
||||
|
@ -89,5 +89,5 @@ func (mh *MagicHeaders) Get(defaultMsgType uint32) (uint32, error) {
|
|||
return 0, fmt.Errorf("invalid msg type: %d", defaultMsgType)
|
||||
}
|
||||
|
||||
return mh.randomGenerator.RandomSizeInRange(mh.headerValues[defaultMsgType-1].Min, mh.headerValues[defaultMsgType-1].Max), nil
|
||||
return mh.randomGenerator.RandomSizeInRange(mh.Values[defaultMsgType-1].Min, mh.Values[defaultMsgType-1].Max), nil
|
||||
}
|
||||
|
|
|
@ -378,7 +378,7 @@ func TestNewMagicHeaders(t *testing.T) {
|
|||
require.Equal(t, MagicHeaders{}, result)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.magicHeaders, result.headerValues)
|
||||
require.Equal(t, tt.magicHeaders, result.Values)
|
||||
require.NotNil(t, result.randomGenerator)
|
||||
}
|
||||
})
|
||||
|
@ -469,7 +469,7 @@ func TestMagicHeaders_Get(t *testing.T) {
|
|||
t.Parallel()
|
||||
// Create a new instance with mock PRNG for each test
|
||||
testMagicHeaders := MagicHeaders{
|
||||
headerValues: headers,
|
||||
Values: headers,
|
||||
randomGenerator: &mockPRNG{returnValue: tt.mockValue},
|
||||
}
|
||||
|
||||
|
|
|
@ -580,6 +580,7 @@ func (device *Device) BindClose() error {
|
|||
device.net.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func (device *Device) isAWG() bool {
|
||||
return device.version >= VersionAwg
|
||||
}
|
||||
|
@ -599,7 +600,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
|
|||
|
||||
var errs []error
|
||||
|
||||
isASecOn := false
|
||||
isAwgOn := false
|
||||
device.awg.Mux.Lock()
|
||||
if tempAwg.Cfg.JunkPacketCount < 0 {
|
||||
errs = append(errs, ipcErrorf(
|
||||
|
@ -610,12 +611,12 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
|
|||
}
|
||||
device.awg.Cfg.JunkPacketCount = tempAwg.Cfg.JunkPacketCount
|
||||
if tempAwg.Cfg.JunkPacketCount != 0 {
|
||||
isASecOn = true
|
||||
isAwgOn = true
|
||||
}
|
||||
|
||||
device.awg.Cfg.JunkPacketMinSize = tempAwg.Cfg.JunkPacketMinSize
|
||||
if tempAwg.Cfg.JunkPacketMinSize != 0 {
|
||||
isASecOn = true
|
||||
isAwgOn = true
|
||||
}
|
||||
|
||||
if device.awg.Cfg.JunkPacketCount > 0 &&
|
||||
|
@ -645,57 +646,72 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
|
|||
}
|
||||
|
||||
if tempAwg.Cfg.JunkPacketMaxSize != 0 {
|
||||
isASecOn = true
|
||||
isAwgOn = true
|
||||
}
|
||||
|
||||
limits := make([]awg.MagicHeader, 4)
|
||||
magicHeaders := make([]awg.MagicHeader, 4)
|
||||
|
||||
if tempAwg.Cfg.InitPacketMagicHeader.Min > 4 {
|
||||
isASecOn = true
|
||||
if len(tempAwg.Cfg.MagicHeaders.Values) != 4 {
|
||||
return ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"magic headers should have 4 values; got: %d",
|
||||
len(tempAwg.Cfg.MagicHeaders.Values),
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
if tempAwg.Cfg.MagicHeaders.Values[0].Min > 4 {
|
||||
isAwgOn = true
|
||||
device.log.Verbosef("UAPI: Updating init_packet_magic_header")
|
||||
device.awg.Cfg.InitPacketMagicHeader = tempAwg.Cfg.InitPacketMagicHeader
|
||||
limits[0] = tempAwg.Cfg.InitPacketMagicHeader
|
||||
MessageInitiationType = device.awg.Cfg.InitPacketMagicHeader.Min
|
||||
magicHeaders[0] = tempAwg.Cfg.MagicHeaders.Values[0]
|
||||
|
||||
MessageInitiationType = magicHeaders[0].Min
|
||||
} else {
|
||||
device.log.Verbosef("UAPI: Using default init type")
|
||||
MessageInitiationType = DefaultMessageInitiationType
|
||||
limits[0] = awg.NewMagicHeaderSameValue(DefaultMessageInitiationType)
|
||||
magicHeaders[0] = awg.NewMagicHeaderSameValue(DefaultMessageInitiationType)
|
||||
}
|
||||
|
||||
if tempAwg.Cfg.ResponsePacketMagicHeader.Min > 4 {
|
||||
isASecOn = true
|
||||
if tempAwg.Cfg.MagicHeaders.Values[1].Min > 4 {
|
||||
isAwgOn = true
|
||||
|
||||
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
|
||||
device.awg.Cfg.ResponsePacketMagicHeader = tempAwg.Cfg.ResponsePacketMagicHeader
|
||||
MessageResponseType = device.awg.Cfg.ResponsePacketMagicHeader.Min
|
||||
limits[1] = tempAwg.Cfg.ResponsePacketMagicHeader
|
||||
magicHeaders[1] = tempAwg.Cfg.MagicHeaders.Values[1]
|
||||
MessageResponseType = magicHeaders[1].Min
|
||||
} else {
|
||||
device.log.Verbosef("UAPI: Using default response type")
|
||||
MessageResponseType = DefaultMessageResponseType
|
||||
limits[1] = awg.NewMagicHeaderSameValue(DefaultMessageResponseType)
|
||||
magicHeaders[1] = awg.NewMagicHeaderSameValue(DefaultMessageResponseType)
|
||||
}
|
||||
|
||||
if tempAwg.Cfg.UnderloadPacketMagicHeader.Min > 4 {
|
||||
isASecOn = true
|
||||
if tempAwg.Cfg.MagicHeaders.Values[2].Min > 4 {
|
||||
isAwgOn = true
|
||||
|
||||
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
|
||||
device.awg.Cfg.UnderloadPacketMagicHeader = tempAwg.Cfg.UnderloadPacketMagicHeader
|
||||
MessageCookieReplyType = device.awg.Cfg.UnderloadPacketMagicHeader.Min
|
||||
limits[2] = tempAwg.Cfg.UnderloadPacketMagicHeader
|
||||
magicHeaders[2] = tempAwg.Cfg.MagicHeaders.Values[2]
|
||||
MessageCookieReplyType = magicHeaders[2].Min
|
||||
} else {
|
||||
device.log.Verbosef("UAPI: Using default underload type")
|
||||
MessageCookieReplyType = DefaultMessageCookieReplyType
|
||||
limits[2] = awg.NewMagicHeaderSameValue(DefaultMessageCookieReplyType)
|
||||
magicHeaders[2] = awg.NewMagicHeaderSameValue(DefaultMessageCookieReplyType)
|
||||
}
|
||||
|
||||
if tempAwg.Cfg.TransportPacketMagicHeader.Min > 4 {
|
||||
isASecOn = true
|
||||
if tempAwg.Cfg.MagicHeaders.Values[3].Min > 4 {
|
||||
isAwgOn = true
|
||||
|
||||
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
|
||||
device.awg.Cfg.TransportPacketMagicHeader = tempAwg.Cfg.TransportPacketMagicHeader
|
||||
MessageTransportType = device.awg.Cfg.TransportPacketMagicHeader.Min
|
||||
limits[3] = tempAwg.Cfg.TransportPacketMagicHeader
|
||||
magicHeaders[3] = tempAwg.Cfg.MagicHeaders.Values[3]
|
||||
MessageTransportType = magicHeaders[3].Min
|
||||
} else {
|
||||
device.log.Verbosef("UAPI: Using default transport type")
|
||||
MessageTransportType = DefaultMessageTransportType
|
||||
limits[3] = awg.NewMagicHeaderSameValue(DefaultMessageTransportType)
|
||||
magicHeaders[3] = awg.NewMagicHeaderSameValue(DefaultMessageTransportType)
|
||||
}
|
||||
|
||||
var err error
|
||||
device.awg.Cfg.MagicHeaders, err = awg.NewMagicHeaders(magicHeaders)
|
||||
if err != nil {
|
||||
errs = append(errs, ipcErrorf(ipc.IpcErrorInvalid, "new magic headers: %w", err))
|
||||
}
|
||||
|
||||
isSameHeaderMap := map[uint32]struct{}{
|
||||
|
@ -705,12 +721,6 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
|
|||
MessageTransportType: {},
|
||||
}
|
||||
|
||||
var err error
|
||||
device.awg.MagicHeaders, err = awg.NewMagicHeaders(limits)
|
||||
if err != nil {
|
||||
errs = append(errs, ipcErrorf(ipc.IpcErrorInvalid, "new magic headers: %w", err))
|
||||
}
|
||||
|
||||
// size will be different if same values
|
||||
if len(isSameHeaderMap) != 4 {
|
||||
errs = append(errs, ipcErrorf(
|
||||
|
@ -739,7 +749,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
|
|||
}
|
||||
|
||||
if tempAwg.Cfg.InitHeaderJunkSize != 0 {
|
||||
isASecOn = true
|
||||
isAwgOn = true
|
||||
}
|
||||
|
||||
newResponseSize := MessageResponseSize + tempAwg.Cfg.ResponseHeaderJunkSize
|
||||
|
@ -757,7 +767,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
|
|||
}
|
||||
|
||||
if tempAwg.Cfg.ResponseHeaderJunkSize != 0 {
|
||||
isASecOn = true
|
||||
isAwgOn = true
|
||||
}
|
||||
|
||||
newCookieSize := MessageCookieReplySize + tempAwg.Cfg.CookieReplyHeaderJunkSize
|
||||
|
@ -775,7 +785,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
|
|||
}
|
||||
|
||||
if tempAwg.Cfg.CookieReplyHeaderJunkSize != 0 {
|
||||
isASecOn = true
|
||||
isAwgOn = true
|
||||
}
|
||||
|
||||
newTransportSize := MessageTransportSize + tempAwg.Cfg.TransportHeaderJunkSize
|
||||
|
@ -793,7 +803,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
|
|||
}
|
||||
|
||||
if tempAwg.Cfg.TransportHeaderJunkSize != 0 {
|
||||
isASecOn = true
|
||||
isAwgOn = true
|
||||
}
|
||||
|
||||
isSameSizeMap := map[int]struct{}{
|
||||
|
@ -829,7 +839,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
|
|||
}
|
||||
}
|
||||
|
||||
device.awg.IsOn.SetTo(isASecOn)
|
||||
device.awg.IsOn.SetTo(isAwgOn)
|
||||
device.awg.JunkCreator = awg.NewJunkCreator(device.awg.Cfg)
|
||||
|
||||
if tempAwg.HandshakeHandler.IsSet {
|
||||
|
|
|
@ -120,19 +120,16 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
|
|||
if device.awg.Cfg.TransportHeaderJunkSize != 0 {
|
||||
sendf("s4=%d", device.awg.Cfg.TransportHeaderJunkSize)
|
||||
}
|
||||
// TODO:
|
||||
// if device.awg.ASecCfg.InitPacketMagicHeader != 0 {
|
||||
// sendf("h1=%d", device.awg.ASecCfg.InitPacketMagicHeader)
|
||||
// }
|
||||
// if device.awg.ASecCfg.ResponsePacketMagicHeader != 0 {
|
||||
// sendf("h2=%d", device.awg.ASecCfg.ResponsePacketMagicHeader)
|
||||
// }
|
||||
// if device.awg.ASecCfg.UnderloadPacketMagicHeader != 0 {
|
||||
// sendf("h3=%d", device.awg.ASecCfg.UnderloadPacketMagicHeader)
|
||||
// }
|
||||
// if device.awg.ASecCfg.TransportPacketMagicHeader != 0 {
|
||||
// sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader)
|
||||
// }
|
||||
for i, magicHeader := range device.awg.Cfg.MagicHeaders.Values {
|
||||
if magicHeader.Min > 4 {
|
||||
if magicHeader.Min == magicHeader.Max {
|
||||
sendf("h%d=%d", i+1, magicHeader.Min)
|
||||
continue
|
||||
}
|
||||
|
||||
sendf("h%d=%d-%d", i+1, magicHeader.Min, magicHeader.Max)
|
||||
}
|
||||
}
|
||||
|
||||
specialJunkIpcFields := device.awg.HandshakeHandler.SpecialJunk.IpcGetFields()
|
||||
for _, field := range specialJunkIpcFields {
|
||||
|
@ -201,6 +198,8 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
|||
deviceConfig := true
|
||||
|
||||
tempAwg := awg.Protocol{}
|
||||
tempAwg.Cfg.MagicHeaders.Values = make([]awg.MagicHeader, 4)
|
||||
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
@ -369,43 +368,38 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
|
|||
device.log.Verbosef("UAPI: Updating transport_packet_junk_size")
|
||||
tempAwg.Cfg.TransportHeaderJunkSize = transportPacketJunkSize
|
||||
tempAwg.Cfg.IsSet = true
|
||||
|
||||
case "h1":
|
||||
initMagicHeader, err := awg.ParseMagicHeader(key, value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
|
||||
}
|
||||
|
||||
tempAwg.Cfg.InitPacketMagicHeader = initMagicHeader
|
||||
tempAwg.Cfg.MagicHeaders.Values[0] = initMagicHeader
|
||||
tempAwg.Cfg.IsSet = true
|
||||
|
||||
case "h2":
|
||||
responseMagicHeader, err := awg.ParseMagicHeader(key, value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
|
||||
}
|
||||
|
||||
tempAwg.Cfg.ResponsePacketMagicHeader = responseMagicHeader
|
||||
tempAwg.Cfg.MagicHeaders.Values[1] = responseMagicHeader
|
||||
tempAwg.Cfg.IsSet = true
|
||||
|
||||
case "h3":
|
||||
cookieReplyMagicHeader, err := awg.ParseMagicHeader(key, value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
|
||||
}
|
||||
|
||||
tempAwg.Cfg.UnderloadPacketMagicHeader = cookieReplyMagicHeader
|
||||
tempAwg.Cfg.MagicHeaders.Values[2] = cookieReplyMagicHeader
|
||||
tempAwg.Cfg.IsSet = true
|
||||
|
||||
case "h4":
|
||||
transportMagicHeader, err := awg.ParseMagicHeader(key, value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
|
||||
}
|
||||
|
||||
tempAwg.Cfg.TransportPacketMagicHeader = transportMagicHeader
|
||||
tempAwg.Cfg.MagicHeaders.Values[3] = transportMagicHeader
|
||||
tempAwg.Cfg.IsSet = true
|
||||
|
||||
case "i1", "i2", "i3", "i4", "i5":
|
||||
if len(value) == 0 {
|
||||
device.log.Verbosef("UAPI: received empty %s", key)
|
||||
|
|
Loading…
Add table
Reference in a new issue