From 6992e187552d434cb67e2f3f0c178cdbaff39cee Mon Sep 17 00:00:00 2001 From: Mark Puha Date: Thu, 10 Jul 2025 19:56:03 +0200 Subject: [PATCH] chore: rename variables --- device/awg/awg.go | 45 ++++---- device/awg/junk_creator.go | 30 ++--- device/awg/junk_creator_test.go | 14 +-- device/device.go | 195 +++++++++++++++++--------------- device/noise-protocol.go | 24 ++-- device/receive.go | 13 ++- device/send.go | 10 +- device/uapi.go | 72 ++++++------ 8 files changed, 206 insertions(+), 197 deletions(-) diff --git a/device/awg/awg.go b/device/awg/awg.go index 336617e..29a4b62 100644 --- a/device/awg/awg.go +++ b/device/awg/awg.go @@ -8,7 +8,7 @@ import ( "github.com/tevino/abool" ) -type aSecCfgType struct { +type Cfg struct { IsSet bool JunkPacketCount int JunkPacketMinSize int @@ -25,42 +25,42 @@ type aSecCfgType struct { } type Protocol struct { - IsASecOn abool.AtomicBool + IsOn abool.AtomicBool // TODO: revision the need of the mutex - ASecMux sync.RWMutex - ASecCfg aSecCfgType - JunkCreator junkCreator + Mux sync.RWMutex + Cfg Cfg + JunkCreator JunkCreator MagicHeaders MagicHeaders HandshakeHandler SpecialHandshakeHandler } func (protocol *Protocol) CreateInitHeaderJunk() ([]byte, error) { - protocol.ASecMux.RLock() - defer protocol.ASecMux.RUnlock() + protocol.Mux.RLock() + defer protocol.Mux.RUnlock() - return protocol.createHeaderJunk(protocol.ASecCfg.InitHeaderJunkSize, 0) + return protocol.createHeaderJunk(protocol.Cfg.InitHeaderJunkSize, 0) } func (protocol *Protocol) CreateResponseHeaderJunk() ([]byte, error) { - protocol.ASecMux.RLock() - defer protocol.ASecMux.RUnlock() + protocol.Mux.RLock() + defer protocol.Mux.RUnlock() - return protocol.createHeaderJunk(protocol.ASecCfg.ResponseHeaderJunkSize, 0) + return protocol.createHeaderJunk(protocol.Cfg.ResponseHeaderJunkSize, 0) } func (protocol *Protocol) CreateCookieReplyHeaderJunk() ([]byte, error) { - protocol.ASecMux.RLock() - defer protocol.ASecMux.RUnlock() + protocol.Mux.RLock() + defer protocol.Mux.RUnlock() - return protocol.createHeaderJunk(protocol.ASecCfg.CookieReplyHeaderJunkSize, 0) + return protocol.createHeaderJunk(protocol.Cfg.CookieReplyHeaderJunkSize, 0) } func (protocol *Protocol) CreateTransportHeaderJunk(packetSize int) ([]byte, error) { - protocol.ASecMux.RLock() - defer protocol.ASecMux.RUnlock() + protocol.Mux.RLock() + defer protocol.Mux.RUnlock() - return protocol.createHeaderJunk(protocol.ASecCfg.TransportHeaderJunkSize, packetSize) + return protocol.createHeaderJunk(protocol.Cfg.TransportHeaderJunkSize, packetSize) } func (protocol *Protocol) createHeaderJunk(junkSize int, extraSize int) ([]byte, error) { @@ -68,7 +68,6 @@ func (protocol *Protocol) createHeaderJunk(junkSize int, extraSize int) ([]byte, return nil, nil } - var junk []byte buf := make([]byte, 0, junkSize+extraSize) writer := bytes.NewBuffer(buf[:0]) @@ -77,19 +76,17 @@ func (protocol *Protocol) createHeaderJunk(junkSize int, extraSize int) ([]byte, return nil, fmt.Errorf("append junk: %w", err) } - junk = writer.Bytes() - - return junk, nil + return writer.Bytes(), nil } -func (protocol *Protocol) GetMagicHeaderMinFor(msgTypeRange uint32) (uint32, error) { +func (protocol *Protocol) GetMagicHeaderMinFor(msgType uint32) (uint32, error) { for _, limit := range protocol.MagicHeaders.headerValues { - if limit.Min <= msgTypeRange && msgTypeRange <= limit.Max { + if limit.Min <= msgType && msgType <= limit.Max { return limit.Min, nil } } - return 0, fmt.Errorf("no header for range: %d", msgTypeRange) + return 0, fmt.Errorf("no header for value: %d", msgType) } func (protocol *Protocol) GetMsgType(defaultMsgType uint32) (uint32, error) { diff --git a/device/awg/junk_creator.go b/device/awg/junk_creator.go index b41441a..8ba2918 100644 --- a/device/awg/junk_creator.go +++ b/device/awg/junk_creator.go @@ -5,23 +5,23 @@ import ( "fmt" ) -type junkCreator struct { - aSecCfg aSecCfgType +type JunkCreator struct { + cfg Cfg randomGenerator PRNG[int] } // TODO: refactor param to only pass the junk related params -func NewJunkCreator(aSecCfg aSecCfgType) junkCreator { - return junkCreator{aSecCfg: aSecCfg, randomGenerator: NewPRNG[int]()} +func NewJunkCreator(cfg Cfg) JunkCreator { + return JunkCreator{cfg: cfg, randomGenerator: NewPRNG[int]()} } -// Should be called with aSecMux RLocked -func (jc *junkCreator) CreateJunkPackets(junks *[][]byte) { - if jc.aSecCfg.JunkPacketCount == 0 { +// Should be called with awg mux RLocked +func (jc *JunkCreator) CreateJunkPackets(junks *[][]byte) { + if jc.cfg.JunkPacketCount == 0 { return } - for range jc.aSecCfg.JunkPacketCount { + for range jc.cfg.JunkPacketCount { packetSize := jc.randomPacketSize() junk := jc.randomJunkWithSize(packetSize) *junks = append(*junks, junk) @@ -29,13 +29,13 @@ func (jc *junkCreator) CreateJunkPackets(junks *[][]byte) { return } -// Should be called with aSecMux RLocked -func (jc *junkCreator) randomPacketSize() int { - return jc.randomGenerator.RandomSizeInRange(jc.aSecCfg.JunkPacketMinSize, jc.aSecCfg.JunkPacketMaxSize) +// Should be called with awg mux RLocked +func (jc *JunkCreator) randomPacketSize() int { + return jc.randomGenerator.RandomSizeInRange(jc.cfg.JunkPacketMinSize, jc.cfg.JunkPacketMaxSize) } -// Should be called with aSecMux RLocked -func (jc *junkCreator) AppendJunk(writer *bytes.Buffer, size int) error { +// Should be called with awg mux RLocked +func (jc *JunkCreator) AppendJunk(writer *bytes.Buffer, size int) error { headerJunk := jc.randomJunkWithSize(size) _, err := writer.Write(headerJunk) if err != nil { @@ -44,7 +44,7 @@ func (jc *junkCreator) AppendJunk(writer *bytes.Buffer, size int) error { return nil } -// Should be called with aSecMux RLocked -func (jc *junkCreator) randomJunkWithSize(size int) []byte { +// Should be called with awg mux RLocked +func (jc *JunkCreator) randomJunkWithSize(size int) []byte { return jc.randomGenerator.ReadSize(size) } diff --git a/device/awg/junk_creator_test.go b/device/awg/junk_creator_test.go index 3553a93..33532e4 100644 --- a/device/awg/junk_creator_test.go +++ b/device/awg/junk_creator_test.go @@ -6,8 +6,8 @@ import ( "testing" ) -func setUpJunkCreator() junkCreator { - jc := NewJunkCreator(aSecCfgType{ +func setUpJunkCreator() JunkCreator { + jc := NewJunkCreator(Cfg{ IsSet: true, JunkPacketCount: 5, JunkPacketMinSize: 500, @@ -27,7 +27,7 @@ func setUpJunkCreator() junkCreator { func Test_junkCreator_createJunkPackets(t *testing.T) { jc := setUpJunkCreator() t.Run("valid", func(t *testing.T) { - got := make([][]byte, 0, jc.aSecCfg.JunkPacketCount) + got := make([][]byte, 0, jc.cfg.JunkPacketCount) jc.CreateJunkPackets(&got) seen := make(map[string]bool) for _, junk := range got { @@ -62,13 +62,13 @@ func Test_junkCreator_randomPacketSize(t *testing.T) { jc := setUpJunkCreator() for range [30]struct{}{} { t.Run("valid", func(t *testing.T) { - if got := jc.randomPacketSize(); jc.aSecCfg.JunkPacketMinSize > got || - got > jc.aSecCfg.JunkPacketMaxSize { + if got := jc.randomPacketSize(); jc.cfg.JunkPacketMinSize > got || + got > jc.cfg.JunkPacketMaxSize { t.Errorf( "junkCreator.randomPacketSize() = %v, not between range [%v,%v]", got, - jc.aSecCfg.JunkPacketMinSize, - jc.aSecCfg.JunkPacketMaxSize, + jc.cfg.JunkPacketMinSize, + jc.cfg.JunkPacketMaxSize, ) } }) diff --git a/device/device.go b/device/device.go index 1a2b726..fe71037 100644 --- a/device/device.go +++ b/device/device.go @@ -593,105 +593,105 @@ func (device *Device) resetProtocol() { } func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { - if !tempAwg.ASecCfg.IsSet && !tempAwg.HandshakeHandler.IsSet { + if !tempAwg.Cfg.IsSet && !tempAwg.HandshakeHandler.IsSet { return nil } var errs []error isASecOn := false - device.awg.ASecMux.Lock() - if tempAwg.ASecCfg.JunkPacketCount < 0 { + device.awg.Mux.Lock() + if tempAwg.Cfg.JunkPacketCount < 0 { errs = append(errs, ipcErrorf( ipc.IpcErrorInvalid, "JunkPacketCount should be non negative", ), ) } - device.awg.ASecCfg.JunkPacketCount = tempAwg.ASecCfg.JunkPacketCount - if tempAwg.ASecCfg.JunkPacketCount != 0 { + device.awg.Cfg.JunkPacketCount = tempAwg.Cfg.JunkPacketCount + if tempAwg.Cfg.JunkPacketCount != 0 { isASecOn = true } - device.awg.ASecCfg.JunkPacketMinSize = tempAwg.ASecCfg.JunkPacketMinSize - if tempAwg.ASecCfg.JunkPacketMinSize != 0 { + device.awg.Cfg.JunkPacketMinSize = tempAwg.Cfg.JunkPacketMinSize + if tempAwg.Cfg.JunkPacketMinSize != 0 { isASecOn = true } - if device.awg.ASecCfg.JunkPacketCount > 0 && - tempAwg.ASecCfg.JunkPacketMaxSize == tempAwg.ASecCfg.JunkPacketMinSize { + if device.awg.Cfg.JunkPacketCount > 0 && + tempAwg.Cfg.JunkPacketMaxSize == tempAwg.Cfg.JunkPacketMinSize { - tempAwg.ASecCfg.JunkPacketMaxSize++ // to make rand gen work + tempAwg.Cfg.JunkPacketMaxSize++ // to make rand gen work } - if tempAwg.ASecCfg.JunkPacketMaxSize >= MaxSegmentSize { - device.awg.ASecCfg.JunkPacketMinSize = 0 - device.awg.ASecCfg.JunkPacketMaxSize = 1 + if tempAwg.Cfg.JunkPacketMaxSize >= MaxSegmentSize { + device.awg.Cfg.JunkPacketMinSize = 0 + device.awg.Cfg.JunkPacketMaxSize = 1 errs = append(errs, ipcErrorf( ipc.IpcErrorInvalid, "JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d", - tempAwg.ASecCfg.JunkPacketMaxSize, + tempAwg.Cfg.JunkPacketMaxSize, MaxSegmentSize, )) - } else if tempAwg.ASecCfg.JunkPacketMaxSize < tempAwg.ASecCfg.JunkPacketMinSize { + } else if tempAwg.Cfg.JunkPacketMaxSize < tempAwg.Cfg.JunkPacketMinSize { errs = append(errs, ipcErrorf( ipc.IpcErrorInvalid, "maxSize: %d; should be greater than minSize: %d", - tempAwg.ASecCfg.JunkPacketMaxSize, - tempAwg.ASecCfg.JunkPacketMinSize, + tempAwg.Cfg.JunkPacketMaxSize, + tempAwg.Cfg.JunkPacketMinSize, )) } else { - device.awg.ASecCfg.JunkPacketMaxSize = tempAwg.ASecCfg.JunkPacketMaxSize + device.awg.Cfg.JunkPacketMaxSize = tempAwg.Cfg.JunkPacketMaxSize } - if tempAwg.ASecCfg.JunkPacketMaxSize != 0 { + if tempAwg.Cfg.JunkPacketMaxSize != 0 { isASecOn = true } limits := make([]awg.MagicHeader, 4) - if tempAwg.ASecCfg.InitPacketMagicHeader.Min > 4 { + if tempAwg.Cfg.InitPacketMagicHeader.Min > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating init_packet_magic_header") - device.awg.ASecCfg.InitPacketMagicHeader = tempAwg.ASecCfg.InitPacketMagicHeader - limits[0] = tempAwg.ASecCfg.InitPacketMagicHeader - MessageInitiationType = device.awg.ASecCfg.InitPacketMagicHeader.Min + device.awg.Cfg.InitPacketMagicHeader = tempAwg.Cfg.InitPacketMagicHeader + limits[0] = tempAwg.Cfg.InitPacketMagicHeader + MessageInitiationType = device.awg.Cfg.InitPacketMagicHeader.Min } else { device.log.Verbosef("UAPI: Using default init type") MessageInitiationType = DefaultMessageInitiationType limits[0] = awg.NewMagicHeaderSameValue(DefaultMessageInitiationType) } - if tempAwg.ASecCfg.ResponsePacketMagicHeader.Min > 4 { + if tempAwg.Cfg.ResponsePacketMagicHeader.Min > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating response_packet_magic_header") - device.awg.ASecCfg.ResponsePacketMagicHeader = tempAwg.ASecCfg.ResponsePacketMagicHeader - MessageResponseType = device.awg.ASecCfg.ResponsePacketMagicHeader.Min - limits[1] = tempAwg.ASecCfg.ResponsePacketMagicHeader + device.awg.Cfg.ResponsePacketMagicHeader = tempAwg.Cfg.ResponsePacketMagicHeader + MessageResponseType = device.awg.Cfg.ResponsePacketMagicHeader.Min + limits[1] = tempAwg.Cfg.ResponsePacketMagicHeader } else { device.log.Verbosef("UAPI: Using default response type") MessageResponseType = DefaultMessageResponseType limits[1] = awg.NewMagicHeaderSameValue(DefaultMessageResponseType) } - if tempAwg.ASecCfg.UnderloadPacketMagicHeader.Min > 4 { + if tempAwg.Cfg.UnderloadPacketMagicHeader.Min > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating underload_packet_magic_header") - device.awg.ASecCfg.UnderloadPacketMagicHeader = tempAwg.ASecCfg.UnderloadPacketMagicHeader - MessageCookieReplyType = device.awg.ASecCfg.UnderloadPacketMagicHeader.Min - limits[2] = tempAwg.ASecCfg.UnderloadPacketMagicHeader + device.awg.Cfg.UnderloadPacketMagicHeader = tempAwg.Cfg.UnderloadPacketMagicHeader + MessageCookieReplyType = device.awg.Cfg.UnderloadPacketMagicHeader.Min + limits[2] = tempAwg.Cfg.UnderloadPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default underload type") MessageCookieReplyType = DefaultMessageCookieReplyType limits[2] = awg.NewMagicHeaderSameValue(DefaultMessageCookieReplyType) } - if tempAwg.ASecCfg.TransportPacketMagicHeader.Min > 4 { + if tempAwg.Cfg.TransportPacketMagicHeader.Min > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating transport_packet_magic_header") - device.awg.ASecCfg.TransportPacketMagicHeader = tempAwg.ASecCfg.TransportPacketMagicHeader - MessageTransportType = device.awg.ASecCfg.TransportPacketMagicHeader.Min - limits[3] = tempAwg.ASecCfg.TransportPacketMagicHeader + device.awg.Cfg.TransportPacketMagicHeader = tempAwg.Cfg.TransportPacketMagicHeader + MessageTransportType = device.awg.Cfg.TransportPacketMagicHeader.Min + limits[3] = tempAwg.Cfg.TransportPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default transport type") MessageTransportType = DefaultMessageTransportType @@ -724,75 +724,75 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { ) } - newInitSize := MessageInitiationSize + tempAwg.ASecCfg.InitHeaderJunkSize + newInitSize := MessageInitiationSize + tempAwg.Cfg.InitHeaderJunkSize if newInitSize >= MaxSegmentSize { errs = append(errs, ipcErrorf( ipc.IpcErrorInvalid, `init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`, - tempAwg.ASecCfg.InitHeaderJunkSize, + tempAwg.Cfg.InitHeaderJunkSize, MaxSegmentSize, ), ) } else { - device.awg.ASecCfg.InitHeaderJunkSize = tempAwg.ASecCfg.InitHeaderJunkSize + device.awg.Cfg.InitHeaderJunkSize = tempAwg.Cfg.InitHeaderJunkSize } - if tempAwg.ASecCfg.InitHeaderJunkSize != 0 { + if tempAwg.Cfg.InitHeaderJunkSize != 0 { isASecOn = true } - newResponseSize := MessageResponseSize + tempAwg.ASecCfg.ResponseHeaderJunkSize + newResponseSize := MessageResponseSize + tempAwg.Cfg.ResponseHeaderJunkSize if newResponseSize >= MaxSegmentSize { errs = append(errs, ipcErrorf( ipc.IpcErrorInvalid, `response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`, - tempAwg.ASecCfg.ResponseHeaderJunkSize, + tempAwg.Cfg.ResponseHeaderJunkSize, MaxSegmentSize, ), ) } else { - device.awg.ASecCfg.ResponseHeaderJunkSize = tempAwg.ASecCfg.ResponseHeaderJunkSize + device.awg.Cfg.ResponseHeaderJunkSize = tempAwg.Cfg.ResponseHeaderJunkSize } - if tempAwg.ASecCfg.ResponseHeaderJunkSize != 0 { + if tempAwg.Cfg.ResponseHeaderJunkSize != 0 { isASecOn = true } - newCookieSize := MessageCookieReplySize + tempAwg.ASecCfg.CookieReplyHeaderJunkSize + newCookieSize := MessageCookieReplySize + tempAwg.Cfg.CookieReplyHeaderJunkSize if newCookieSize >= MaxSegmentSize { errs = append(errs, ipcErrorf( ipc.IpcErrorInvalid, `cookie reply size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`, - tempAwg.ASecCfg.CookieReplyHeaderJunkSize, + tempAwg.Cfg.CookieReplyHeaderJunkSize, MaxSegmentSize, ), ) } else { - device.awg.ASecCfg.CookieReplyHeaderJunkSize = tempAwg.ASecCfg.CookieReplyHeaderJunkSize + device.awg.Cfg.CookieReplyHeaderJunkSize = tempAwg.Cfg.CookieReplyHeaderJunkSize } - if tempAwg.ASecCfg.CookieReplyHeaderJunkSize != 0 { + if tempAwg.Cfg.CookieReplyHeaderJunkSize != 0 { isASecOn = true } - newTransportSize := MessageTransportSize + tempAwg.ASecCfg.TransportHeaderJunkSize + newTransportSize := MessageTransportSize + tempAwg.Cfg.TransportHeaderJunkSize if newTransportSize >= MaxSegmentSize { errs = append(errs, ipcErrorf( ipc.IpcErrorInvalid, `transport size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`, - tempAwg.ASecCfg.TransportHeaderJunkSize, + tempAwg.Cfg.TransportHeaderJunkSize, MaxSegmentSize, ), ) } else { - device.awg.ASecCfg.TransportHeaderJunkSize = tempAwg.ASecCfg.TransportHeaderJunkSize + device.awg.Cfg.TransportHeaderJunkSize = tempAwg.Cfg.TransportHeaderJunkSize } - if tempAwg.ASecCfg.TransportHeaderJunkSize != 0 { + if tempAwg.Cfg.TransportHeaderJunkSize != 0 { isASecOn = true } @@ -815,10 +815,10 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { ) } else { msgTypeToJunkSize = map[uint32]int{ - MessageInitiationType: device.awg.ASecCfg.InitHeaderJunkSize, - MessageResponseType: device.awg.ASecCfg.ResponseHeaderJunkSize, - MessageCookieReplyType: device.awg.ASecCfg.CookieReplyHeaderJunkSize, - MessageTransportType: device.awg.ASecCfg.TransportHeaderJunkSize, + MessageInitiationType: device.awg.Cfg.InitHeaderJunkSize, + MessageResponseType: device.awg.Cfg.ResponseHeaderJunkSize, + MessageCookieReplyType: device.awg.Cfg.CookieReplyHeaderJunkSize, + MessageTransportType: device.awg.Cfg.TransportHeaderJunkSize, } packetSizeToMsgType = map[int]uint32{ @@ -829,8 +829,8 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { } } - device.awg.IsASecOn.SetTo(isASecOn) - device.awg.JunkCreator = awg.NewJunkCreator(device.awg.ASecCfg) + device.awg.IsOn.SetTo(isASecOn) + device.awg.JunkCreator = awg.NewJunkCreator(device.awg.Cfg) if tempAwg.HandshakeHandler.IsSet { if err := tempAwg.HandshakeHandler.Validate(); err != nil { @@ -838,78 +838,89 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { ipc.IpcErrorInvalid, "handshake handler validate: %w", err)) } else { device.awg.HandshakeHandler = tempAwg.HandshakeHandler - device.awg.HandshakeHandler.ControlledJunk.DefaultJunkCount = tempAwg.ASecCfg.JunkPacketCount - device.awg.HandshakeHandler.SpecialJunk.DefaultJunkCount = tempAwg.ASecCfg.JunkPacketCount + device.awg.HandshakeHandler.ControlledJunk.DefaultJunkCount = tempAwg.Cfg.JunkPacketCount + device.awg.HandshakeHandler.SpecialJunk.DefaultJunkCount = tempAwg.Cfg.JunkPacketCount device.version = VersionAwgSpecialHandshake } } else { device.version = VersionAwg } - device.awg.ASecMux.Unlock() + device.awg.Mux.Unlock() return errors.Join(errs...) } -func (device *Device) Logic(size int, packet *[]byte, bufsArrs *[MaxMessageSize]byte) (msgType uint32, err error) { +func (device *Device) ProcessAWGPacket(size int, packet *[]byte, buffer *[MaxMessageSize]byte) (uint32, error) { // TODO: // if awg.WaitResponse.ShouldWait.IsSet() { // awg.WaitResponse.Channel <- struct{}{} // } - assumedMsgType, hasAssumedType := packetSizeToMsgType[size] - if !hasAssumedType { - return device.handleTransport(size, packet, bufsArrs) - } + expectedMsgType, isKnownSize := packetSizeToMsgType[size] + if !isKnownSize { + msgType, err := device.handleTransport(size, packet, buffer) - junkSize := msgTypeToJunkSize[assumedMsgType] + if err != nil { + return 0, fmt.Errorf("handle transport: %w", err) + } - // transport size can align with other header types; - // making sure we have the right msgType - msgType, err = device.getMsgType(packet, junkSize) - if err != nil { - return 0, fmt.Errorf("aSec: get msg type: %w", err) - } - - if msgType == assumedMsgType { - *packet = (*packet)[junkSize:] return msgType, nil } - device.log.Verbosef("transport packet lined up with another msg type") - - return device.handleTransport(size, packet, bufsArrs) -} - -func (device *Device) getMsgType(packet *[]byte, junkSize int) (uint32, error) { - msgTypeRange := binary.LittleEndian.Uint32((*packet)[junkSize : junkSize+4]) - msgType, err := device.awg.GetMagicHeaderMinFor(msgTypeRange) + junkSize := msgTypeToJunkSize[expectedMsgType] + // transport size can align with other header types; + // making sure we have the right actualMsgType + actualMsgType, err := device.getMsgType(packet, junkSize) if err != nil { - return 0, fmt.Errorf("get limit min: %w", err) + return 0, fmt.Errorf("get msg type: %w", err) + } + + if actualMsgType == expectedMsgType { + *packet = (*packet)[junkSize:] + return actualMsgType, nil + } + + device.log.Verbosef("awg: transport packet lined up with another msg type") + + msgType, err := device.handleTransport(size, packet, buffer) + if err != nil { + return 0, fmt.Errorf("handle transport: %w", err) } return msgType, nil } -func (device *Device) handleTransport(size int, packet *[]byte, bufsArrs *[MaxMessageSize]byte) (uint32, error) { - transportJunkSize := device.awg.ASecCfg.TransportHeaderJunkSize +func (device *Device) getMsgType(packet *[]byte, junkSize int) (uint32, error) { + msgTypeValue := binary.LittleEndian.Uint32((*packet)[junkSize : junkSize+4]) + msgType, err := device.awg.GetMagicHeaderMinFor(msgTypeValue) - msgType, err := device.getMsgType(packet, device.awg.ASecCfg.TransportHeaderJunkSize) + if err != nil { + return 0, fmt.Errorf("get magic header min: %w", err) + } + + return msgType, nil +} + +func (device *Device) handleTransport(size int, packet *[]byte, buffer *[MaxMessageSize]byte) (uint32, error) { + junkSize := device.awg.Cfg.TransportHeaderJunkSize + + msgType, err := device.getMsgType(packet, junkSize) if err != nil { return 0, fmt.Errorf("get msg type: %w", err) } if msgType != MessageTransportType { // probably a junk packet - return 0, fmt.Errorf("aSec: Received message with unknown type: %d", msgType) + return 0, fmt.Errorf("Received message with unknown type: %d", msgType) } - if transportJunkSize > 0 { - // remove junk from bufsArrs by shifting the packet + if junkSize > 0 { + // remove junk from buffer by shifting the packet // this buffer is also used for decryption, so it needs to be corrected - copy((*bufsArrs)[:size], (*packet)[transportJunkSize:]) - size -= transportJunkSize + copy((*buffer)[:size], (*packet)[junkSize:]) + size -= junkSize // need to reinitialize packet as well (*packet) = (*packet)[:size] } diff --git a/device/noise-protocol.go b/device/noise-protocol.go index bce2d2f..bac089d 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -205,10 +205,10 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e handshake.mixHash(handshake.remoteStatic[:]) - device.awg.ASecMux.RLock() + device.awg.Mux.RLock() msgType, err := device.awg.GetMsgType(DefaultMessageInitiationType) if err != nil { - device.awg.ASecMux.RUnlock() + device.awg.Mux.RUnlock() return nil, fmt.Errorf("get message type: %w", err) } @@ -216,7 +216,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e Type: msgType, Ephemeral: handshake.localEphemeral.publicKey(), } - device.awg.ASecMux.RUnlock() + device.awg.Mux.RUnlock() handshake.mixKey(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:]) @@ -270,13 +270,13 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { chainKey [blake2s.Size]byte ) - device.awg.ASecMux.RLock() + device.awg.Mux.RLock() if msg.Type != MessageInitiationType { - device.awg.ASecMux.RUnlock() + device.awg.Mux.RUnlock() return nil } - device.awg.ASecMux.RUnlock() + device.awg.Mux.RUnlock() device.staticIdentity.RLock() defer device.staticIdentity.RUnlock() @@ -391,14 +391,14 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error } var msg MessageResponse - device.awg.ASecMux.RLock() + device.awg.Mux.RLock() msg.Type, err = device.awg.GetMsgType(DefaultMessageResponseType) if err != nil { - device.awg.ASecMux.RUnlock() + device.awg.Mux.RUnlock() return nil, fmt.Errorf("get message type: %w", err) } - device.awg.ASecMux.RUnlock() + device.awg.Mux.RUnlock() msg.Sender = handshake.localIndex msg.Receiver = handshake.remoteIndex @@ -448,13 +448,13 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error } func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { - device.awg.ASecMux.RLock() + device.awg.Mux.RLock() if msg.Type != MessageResponseType { - device.awg.ASecMux.RUnlock() + device.awg.Mux.RUnlock() return nil } - device.awg.ASecMux.RUnlock() + device.awg.Mux.RUnlock() // lookup handshake by receiver diff --git a/device/receive.go b/device/receive.go index 41ffa2b..8ace3a6 100644 --- a/device/receive.go +++ b/device/receive.go @@ -129,7 +129,7 @@ func (device *Device) RoutineReceiveIncoming( } deathSpiral = 0 - device.awg.ASecMux.RLock() + device.awg.Mux.RLock() // handle each packet in the batch for i, size := range sizes[:count] { if size < MinMessageSize { @@ -140,9 +140,10 @@ func (device *Device) RoutineReceiveIncoming( packet := bufsArrs[i][:size] var msgType uint32 if device.isAWG() { - msgType, err = device.Logic(size, &packet, bufsArrs[i]) + msgType, err = device.ProcessAWGPacket(size, &packet, bufsArrs[i]) + if err != nil { - device.log.Verbosef("awg device logic: %v", err) + device.log.Verbosef("awg: process packet: %v", err) continue } } else { @@ -232,7 +233,7 @@ func (device *Device) RoutineReceiveIncoming( default: } } - device.awg.ASecMux.RUnlock() + device.awg.Mux.RUnlock() for peer, elemsContainer := range elemsByPeer { if peer.isRunning.Load() { peer.queue.inbound.c <- elemsContainer @@ -291,7 +292,7 @@ func (device *Device) RoutineHandshake(id int) { for elem := range device.queue.handshake.c { - device.awg.ASecMux.RLock() + device.awg.Mux.RLock() // handle cookie fields and ratelimiting @@ -449,7 +450,7 @@ func (device *Device) RoutineHandshake(id int) { peer.SendKeepalive() } skip: - device.awg.ASecMux.RUnlock() + device.awg.Mux.RUnlock() device.PutMessageBuffer(elem.buffer) } } diff --git a/device/send.go b/device/send.go index b87dda2..296e9a8 100644 --- a/device/send.go +++ b/device/send.go @@ -130,7 +130,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { if peer.device.version >= VersionAwg { var junks [][]byte if peer.device.version == VersionAwgSpecialHandshake { - peer.device.awg.ASecMux.RLock() + peer.device.awg.Mux.RLock() // set junks depending on packet type junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk() if junks == nil { @@ -141,13 +141,13 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { } else { peer.device.log.Verbosef("%v - Special junks sent", peer) } - peer.device.awg.ASecMux.RUnlock() + peer.device.awg.Mux.RUnlock() } else { - junks = make([][]byte, 0, peer.device.awg.ASecCfg.JunkPacketCount) + junks = make([][]byte, 0, peer.device.awg.Cfg.JunkPacketCount) } - peer.device.awg.ASecMux.RLock() + peer.device.awg.Mux.RLock() peer.device.awg.JunkCreator.CreateJunkPackets(&junks) - peer.device.awg.ASecMux.RUnlock() + peer.device.awg.Mux.RUnlock() if len(junks) > 0 { err = peer.SendBuffers(junks) diff --git a/device/uapi.go b/device/uapi.go index d99a4ee..f4cb834 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -99,26 +99,26 @@ func (device *Device) IpcGetOperation(w io.Writer) error { } if device.isAWG() { - if device.awg.ASecCfg.JunkPacketCount != 0 { - sendf("jc=%d", device.awg.ASecCfg.JunkPacketCount) + if device.awg.Cfg.JunkPacketCount != 0 { + sendf("jc=%d", device.awg.Cfg.JunkPacketCount) } - if device.awg.ASecCfg.JunkPacketMinSize != 0 { - sendf("jmin=%d", device.awg.ASecCfg.JunkPacketMinSize) + if device.awg.Cfg.JunkPacketMinSize != 0 { + sendf("jmin=%d", device.awg.Cfg.JunkPacketMinSize) } - if device.awg.ASecCfg.JunkPacketMaxSize != 0 { - sendf("jmax=%d", device.awg.ASecCfg.JunkPacketMaxSize) + if device.awg.Cfg.JunkPacketMaxSize != 0 { + sendf("jmax=%d", device.awg.Cfg.JunkPacketMaxSize) } - if device.awg.ASecCfg.InitHeaderJunkSize != 0 { - sendf("s1=%d", device.awg.ASecCfg.InitHeaderJunkSize) + if device.awg.Cfg.InitHeaderJunkSize != 0 { + sendf("s1=%d", device.awg.Cfg.InitHeaderJunkSize) } - if device.awg.ASecCfg.ResponseHeaderJunkSize != 0 { - sendf("s2=%d", device.awg.ASecCfg.ResponseHeaderJunkSize) + if device.awg.Cfg.ResponseHeaderJunkSize != 0 { + sendf("s2=%d", device.awg.Cfg.ResponseHeaderJunkSize) } - if device.awg.ASecCfg.CookieReplyHeaderJunkSize != 0 { - sendf("s3=%d", device.awg.ASecCfg.CookieReplyHeaderJunkSize) + if device.awg.Cfg.CookieReplyHeaderJunkSize != 0 { + sendf("s3=%d", device.awg.Cfg.CookieReplyHeaderJunkSize) } - if device.awg.ASecCfg.TransportHeaderJunkSize != 0 { - sendf("s4=%d", device.awg.ASecCfg.TransportHeaderJunkSize) + if device.awg.Cfg.TransportHeaderJunkSize != 0 { + sendf("s4=%d", device.awg.Cfg.TransportHeaderJunkSize) } // TODO: // if device.awg.ASecCfg.InitPacketMagicHeader != 0 { @@ -313,8 +313,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol) return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_count %w", err) } device.log.Verbosef("UAPI: Updating junk_packet_count") - tempAwg.ASecCfg.JunkPacketCount = junkPacketCount - tempAwg.ASecCfg.IsSet = true + tempAwg.Cfg.JunkPacketCount = junkPacketCount + tempAwg.Cfg.IsSet = true case "jmin": junkPacketMinSize, err := strconv.Atoi(value) @@ -322,8 +322,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol) return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_min_size %w", err) } device.log.Verbosef("UAPI: Updating junk_packet_min_size") - tempAwg.ASecCfg.JunkPacketMinSize = junkPacketMinSize - tempAwg.ASecCfg.IsSet = true + tempAwg.Cfg.JunkPacketMinSize = junkPacketMinSize + tempAwg.Cfg.IsSet = true case "jmax": junkPacketMaxSize, err := strconv.Atoi(value) @@ -331,8 +331,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol) return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_max_size %w", err) } device.log.Verbosef("UAPI: Updating junk_packet_max_size") - tempAwg.ASecCfg.JunkPacketMaxSize = junkPacketMaxSize - tempAwg.ASecCfg.IsSet = true + tempAwg.Cfg.JunkPacketMaxSize = junkPacketMaxSize + tempAwg.Cfg.IsSet = true case "s1": initPacketJunkSize, err := strconv.Atoi(value) @@ -340,8 +340,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol) return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_junk_size %w", err) } device.log.Verbosef("UAPI: Updating init_packet_junk_size") - tempAwg.ASecCfg.InitHeaderJunkSize = initPacketJunkSize - tempAwg.ASecCfg.IsSet = true + tempAwg.Cfg.InitHeaderJunkSize = initPacketJunkSize + tempAwg.Cfg.IsSet = true case "s2": responsePacketJunkSize, err := strconv.Atoi(value) @@ -349,8 +349,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol) return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_junk_size %w", err) } device.log.Verbosef("UAPI: Updating response_packet_junk_size") - tempAwg.ASecCfg.ResponseHeaderJunkSize = responsePacketJunkSize - tempAwg.ASecCfg.IsSet = true + tempAwg.Cfg.ResponseHeaderJunkSize = responsePacketJunkSize + tempAwg.Cfg.IsSet = true case "s3": cookieReplyPacketJunkSize, err := strconv.Atoi(value) @@ -358,8 +358,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol) return ipcErrorf(ipc.IpcErrorInvalid, "parse cookie_reply_packet_junk_size %w", err) } device.log.Verbosef("UAPI: Updating cookie_reply_packet_junk_size") - tempAwg.ASecCfg.CookieReplyHeaderJunkSize = cookieReplyPacketJunkSize - tempAwg.ASecCfg.IsSet = true + tempAwg.Cfg.CookieReplyHeaderJunkSize = cookieReplyPacketJunkSize + tempAwg.Cfg.IsSet = true case "s4": transportPacketJunkSize, err := strconv.Atoi(value) @@ -367,8 +367,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol) return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_junk_size %w", err) } device.log.Verbosef("UAPI: Updating transport_packet_junk_size") - tempAwg.ASecCfg.TransportHeaderJunkSize = transportPacketJunkSize - tempAwg.ASecCfg.IsSet = true + tempAwg.Cfg.TransportHeaderJunkSize = transportPacketJunkSize + tempAwg.Cfg.IsSet = true case "h1": initMagicHeader, err := awg.ParseMagicHeader(key, value) @@ -376,8 +376,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol) return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err) } - tempAwg.ASecCfg.InitPacketMagicHeader = initMagicHeader - tempAwg.ASecCfg.IsSet = true + tempAwg.Cfg.InitPacketMagicHeader = initMagicHeader + tempAwg.Cfg.IsSet = true case "h2": responseMagicHeader, err := awg.ParseMagicHeader(key, value) @@ -385,8 +385,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol) return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err) } - tempAwg.ASecCfg.ResponsePacketMagicHeader = responseMagicHeader - tempAwg.ASecCfg.IsSet = true + tempAwg.Cfg.ResponsePacketMagicHeader = responseMagicHeader + tempAwg.Cfg.IsSet = true case "h3": cookieReplyMagicHeader, err := awg.ParseMagicHeader(key, value) @@ -394,8 +394,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol) return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err) } - tempAwg.ASecCfg.UnderloadPacketMagicHeader = cookieReplyMagicHeader - tempAwg.ASecCfg.IsSet = true + tempAwg.Cfg.UnderloadPacketMagicHeader = cookieReplyMagicHeader + tempAwg.Cfg.IsSet = true case "h4": transportMagicHeader, err := awg.ParseMagicHeader(key, value) @@ -403,8 +403,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol) return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err) } - tempAwg.ASecCfg.TransportPacketMagicHeader = transportMagicHeader - tempAwg.ASecCfg.IsSet = true + tempAwg.Cfg.TransportPacketMagicHeader = transportMagicHeader + tempAwg.Cfg.IsSet = true case "i1", "i2", "i3", "i4", "i5": if len(value) == 0 {