diff --git a/device/device.go b/device/device.go index aaddcbc..f375663 100644 --- a/device/device.go +++ b/device/device.go @@ -92,14 +92,17 @@ type Device struct { ipcMutex sync.RWMutex closed chan struct{} log *Logger + + awg awgType +} +type awgType struct { isASecOn abool.AtomicBool aSecMux sync.RWMutex aSecCfg aSecCfgType junkCreator junkCreator luaAdapter *adapter.Lua - packetCounter atomic.Int64 } type aSecCfgType struct { @@ -431,9 +434,9 @@ func (device *Device) Close() { device.resetProtocol() - if device.luaAdapter != nil { - device.luaAdapter.Close() - device.luaAdapter = nil + if device.awg.luaAdapter != nil { + device.awg.luaAdapter.Close() + device.awg.luaAdapter = nil } device.log.Verbosef("Device closed") close(device.closed) @@ -581,7 +584,7 @@ func (device *Device) BindClose() error { return err } func (device *Device) isAdvancedSecurityOn() bool { - return device.isASecOn.IsSet() + return device.awg.isASecOn.IsSet() } func (device *Device) resetProtocol() { @@ -594,36 +597,36 @@ func (device *Device) resetProtocol() { func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { if !tempASecCfg.isSet { - return err + return nil } isASecOn := false - device.aSecMux.Lock() + device.awg.aSecMux.Lock() if tempASecCfg.junkPacketCount < 0 { err = ipcErrorf( ipc.IpcErrorInvalid, "JunkPacketCount should be non negative", ) } - device.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount + device.awg.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount if tempASecCfg.junkPacketCount != 0 { isASecOn = true } - device.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize + device.awg.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize if tempASecCfg.junkPacketMinSize != 0 { isASecOn = true } - if device.aSecCfg.junkPacketCount > 0 && + if device.awg.aSecCfg.junkPacketCount > 0 && tempASecCfg.junkPacketMaxSize == tempASecCfg.junkPacketMinSize { tempASecCfg.junkPacketMaxSize++ // to make rand gen work } if tempASecCfg.junkPacketMaxSize >= MaxSegmentSize { - device.aSecCfg.junkPacketMinSize = 0 - device.aSecCfg.junkPacketMaxSize = 1 + device.awg.aSecCfg.junkPacketMinSize = 0 + device.awg.aSecCfg.junkPacketMaxSize = 1 if err != nil { err = ipcErrorf( ipc.IpcErrorInvalid, @@ -658,7 +661,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { ) } } else { - device.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize + device.awg.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize } if tempASecCfg.junkPacketMaxSize != 0 { @@ -683,7 +686,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { ) } } else { - device.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize + device.awg.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize } if tempASecCfg.initPacketJunkSize != 0 { @@ -708,7 +711,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { ) } } else { - device.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize + device.awg.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize } if tempASecCfg.responsePacketJunkSize != 0 { @@ -718,8 +721,8 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { if tempASecCfg.initPacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating init_packet_magic_header") - device.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader - MessageInitiationType = device.aSecCfg.initPacketMagicHeader + device.awg.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader + MessageInitiationType = device.awg.aSecCfg.initPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default init type") MessageInitiationType = DefaultMessageInitiationType @@ -728,8 +731,8 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { if tempASecCfg.responsePacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating response_packet_magic_header") - device.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader - MessageResponseType = device.aSecCfg.responsePacketMagicHeader + device.awg.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader + MessageResponseType = device.awg.aSecCfg.responsePacketMagicHeader } else { device.log.Verbosef("UAPI: Using default response type") MessageResponseType = DefaultMessageResponseType @@ -738,8 +741,8 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { if tempASecCfg.underloadPacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating underload_packet_magic_header") - device.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader - MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader + device.awg.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader + MessageCookieReplyType = device.awg.aSecCfg.underloadPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default underload type") MessageCookieReplyType = DefaultMessageCookieReplyType @@ -748,8 +751,8 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { if tempASecCfg.transportPacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating transport_packet_magic_header") - device.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader - MessageTransportType = device.aSecCfg.transportPacketMagicHeader + device.awg.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader + MessageTransportType = device.awg.aSecCfg.transportPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default transport type") MessageTransportType = DefaultMessageTransportType @@ -785,8 +788,8 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { } } - newInitSize := MessageInitiationSize + device.aSecCfg.initPacketJunkSize - newResponseSize := MessageResponseSize + device.aSecCfg.responsePacketJunkSize + newInitSize := MessageInitiationSize + device.awg.aSecCfg.initPacketJunkSize + newResponseSize := MessageResponseSize + device.awg.aSecCfg.responsePacketJunkSize if newInitSize == newResponseSize { if err != nil { @@ -814,16 +817,18 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { } msgTypeToJunkSize = map[uint32]int{ - MessageInitiationType: device.aSecCfg.initPacketJunkSize, - MessageResponseType: device.aSecCfg.responsePacketJunkSize, + MessageInitiationType: device.awg.aSecCfg.initPacketJunkSize, + MessageResponseType: device.awg.aSecCfg.responsePacketJunkSize, MessageCookieReplyType: 0, MessageTransportType: 0, } } - device.isASecOn.SetTo(isASecOn) - device.junkCreator, err = NewJunkCreator(device) - device.aSecMux.Unlock() + device.awg.isASecOn.SetTo(isASecOn) + if device.awg.isASecOn.IsSet() { + device.awg.junkCreator, err = NewJunkCreator(device) + } + device.awg.aSecMux.Unlock() return err } diff --git a/device/junk_creator.go b/device/junk_creator.go index 85a5bbc..2bfb197 100644 --- a/device/junk_creator.go +++ b/device/junk_creator.go @@ -23,12 +23,12 @@ func NewJunkCreator(d *Device) (junkCreator, error) { // Should be called with aSecMux RLocked func (jc *junkCreator) createJunkPackets(peer *Peer) ([][]byte, error) { - if jc.device.aSecCfg.junkPacketCount == 0 { + if jc.device.awg.aSecCfg.junkPacketCount == 0 { return nil, nil } - junks := make([][]byte, 0, jc.device.aSecCfg.junkPacketCount) - for i := 0; i < jc.device.aSecCfg.junkPacketCount; i++ { + junks := make([][]byte, 0, jc.device.awg.aSecCfg.junkPacketCount) + for i := 0; i < jc.device.awg.aSecCfg.junkPacketCount; i++ { packetSize := jc.randomPacketSize() junk, err := jc.randomJunkWithSize(packetSize) if err != nil { @@ -48,9 +48,9 @@ func (jc *junkCreator) createJunkPackets(peer *Peer) ([][]byte, error) { func (jc *junkCreator) randomPacketSize() int { return int( jc.cha8Rand.Uint64()%uint64( - jc.device.aSecCfg.junkPacketMaxSize-jc.device.aSecCfg.junkPacketMinSize, + jc.device.awg.aSecCfg.junkPacketMaxSize-jc.device.awg.aSecCfg.junkPacketMinSize, ), - ) + jc.device.aSecCfg.junkPacketMinSize + ) + jc.device.awg.aSecCfg.junkPacketMinSize } // Should be called with aSecMux RLocked diff --git a/device/junk_creator_test.go b/device/junk_creator_test.go index 6f63360..96fa6d3 100644 --- a/device/junk_creator_test.go +++ b/device/junk_creator_test.go @@ -91,13 +91,13 @@ func Test_junkCreator_randomPacketSize(t *testing.T) { } for range [30]struct{}{} { t.Run("", func(t *testing.T) { - if got := jc.randomPacketSize(); jc.device.aSecCfg.junkPacketMinSize > got || - got > jc.device.aSecCfg.junkPacketMaxSize { + if got := jc.randomPacketSize(); jc.device.awg.aSecCfg.junkPacketMinSize > got || + got > jc.device.awg.aSecCfg.junkPacketMaxSize { t.Errorf( "junkCreator.randomPacketSize() = %v, not between range [%v,%v]", got, - jc.device.aSecCfg.junkPacketMinSize, - jc.device.aSecCfg.junkPacketMaxSize, + jc.device.awg.aSecCfg.junkPacketMinSize, + jc.device.awg.aSecCfg.junkPacketMaxSize, ) } }) diff --git a/device/noise-protocol.go b/device/noise-protocol.go index d818238..2dc0d93 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -204,12 +204,12 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e handshake.mixHash(handshake.remoteStatic[:]) - device.aSecMux.RLock() + device.awg.aSecMux.RLock() msg := MessageInitiation{ Type: MessageInitiationType, Ephemeral: handshake.localEphemeral.publicKey(), } - device.aSecMux.RUnlock() + device.awg.aSecMux.RUnlock() handshake.mixKey(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:]) @@ -263,12 +263,12 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { chainKey [blake2s.Size]byte ) - device.aSecMux.RLock() + device.awg.aSecMux.RLock() if msg.Type != MessageInitiationType { - device.aSecMux.RUnlock() + device.awg.aSecMux.RUnlock() return nil } - device.aSecMux.RUnlock() + device.awg.aSecMux.RUnlock() device.staticIdentity.RLock() defer device.staticIdentity.RUnlock() @@ -383,9 +383,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error } var msg MessageResponse - device.aSecMux.RLock() + device.awg.aSecMux.RLock() msg.Type = MessageResponseType - device.aSecMux.RUnlock() + device.awg.aSecMux.RUnlock() msg.Sender = handshake.localIndex msg.Receiver = handshake.remoteIndex @@ -435,12 +435,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error } func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { - device.aSecMux.RLock() + device.awg.aSecMux.RLock() if msg.Type != MessageResponseType { - device.aSecMux.RUnlock() + device.awg.aSecMux.RUnlock() return nil } - device.aSecMux.RUnlock() + device.awg.aSecMux.RUnlock() // lookup handshake by receiver diff --git a/device/receive.go b/device/receive.go index e790048..70801ef 100644 --- a/device/receive.go +++ b/device/receive.go @@ -129,7 +129,7 @@ func (device *Device) RoutineReceiveIncoming( } deathSpiral = 0 - device.aSecMux.RLock() + device.awg.aSecMux.RLock() // handle each packet in the batch for i, size := range sizes[:count] { if size < MinMessageSize { @@ -138,8 +138,8 @@ func (device *Device) RoutineReceiveIncoming( // check size of packet packet := bufsArrs[i][:size] - if device.luaAdapter != nil { - packet, err = device.luaAdapter.Parse(packet) + if device.awg.luaAdapter != nil { + packet, err = device.awg.luaAdapter.Parse(packet) if err != nil { device.log.Verbosef("Couldn't parse message; reason: %v", err) continue @@ -251,7 +251,7 @@ func (device *Device) RoutineReceiveIncoming( default: } } - device.aSecMux.RUnlock() + device.awg.aSecMux.RUnlock() for peer, elemsContainer := range elemsByPeer { if peer.isRunning.Load() { peer.queue.inbound.c <- elemsContainer @@ -310,7 +310,7 @@ func (device *Device) RoutineHandshake(id int) { for elem := range device.queue.handshake.c { - device.aSecMux.RLock() + device.awg.aSecMux.RLock() // handle cookie fields and ratelimiting @@ -462,7 +462,7 @@ func (device *Device) RoutineHandshake(id int) { peer.SendKeepalive() } skip: - device.aSecMux.RUnlock() + device.awg.aSecMux.RUnlock() device.PutMessageBuffer(elem.buffer) } } diff --git a/device/send.go b/device/send.go index ac05313..297f19e 100644 --- a/device/send.go +++ b/device/send.go @@ -131,9 +131,9 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { // so only packet processed for cookie generation var junkedHeader []byte if peer.device.isAdvancedSecurityOn() { - peer.device.aSecMux.RLock() - junks, err := peer.device.junkCreator.createJunkPackets(peer) - peer.device.aSecMux.RUnlock() + peer.device.awg.aSecMux.RLock() + junks, err := peer.device.awg.junkCreator.createJunkPackets(peer) + peer.device.awg.aSecMux.RUnlock() if err != nil { peer.device.log.Errorf("%v - %v", peer, err) @@ -153,19 +153,19 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { } } - peer.device.aSecMux.RLock() - if peer.device.aSecCfg.initPacketJunkSize != 0 { - buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize) + peer.device.awg.aSecMux.RLock() + if peer.device.awg.aSecCfg.initPacketJunkSize != 0 { + buf := make([]byte, 0, peer.device.awg.aSecCfg.initPacketJunkSize) writer := bytes.NewBuffer(buf[:0]) - err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize) + err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.initPacketJunkSize) if err != nil { peer.device.log.Errorf("%v - %v", peer, err) - peer.device.aSecMux.RUnlock() + peer.device.awg.aSecMux.RUnlock() return err } junkedHeader = writer.Bytes() } - peer.device.aSecMux.RUnlock() + peer.device.awg.aSecMux.RUnlock() } var buf [MessageInitiationSize]byte @@ -215,19 +215,19 @@ func (peer *Peer) SendHandshakeResponse() error { } var junkedHeader []byte if peer.device.isAdvancedSecurityOn() { - peer.device.aSecMux.RLock() - if peer.device.aSecCfg.responsePacketJunkSize != 0 { - buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize) + peer.device.awg.aSecMux.RLock() + if peer.device.awg.aSecCfg.responsePacketJunkSize != 0 { + buf := make([]byte, 0, peer.device.awg.aSecCfg.responsePacketJunkSize) writer := bytes.NewBuffer(buf[:0]) - err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize) + err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.responsePacketJunkSize) if err != nil { - peer.device.aSecMux.RUnlock() + peer.device.awg.aSecMux.RUnlock() peer.device.log.Errorf("%v - %v", peer, err) return err } junkedHeader = writer.Bytes() } - peer.device.aSecMux.RUnlock() + peer.device.awg.aSecMux.RUnlock() } var buf [MessageResponseSize]byte writer := bytes.NewBuffer(buf[:0]) @@ -548,9 +548,9 @@ func calculatePaddingSize(packetSize, mtu int) int { } func (device *Device) codecPacket(msgType uint32, packet []byte) ([]byte, error) { - if device.luaAdapter != nil { + if device.awg.luaAdapter != nil { var err error - packet, err = device.luaAdapter.Generate(int64(msgType),packet, device.packetCounter.Add(1)) + packet, err = device.awg.luaAdapter.Generate(int64(msgType),packet) if err != nil { device.log.Errorf("%v - Failed to run codec generate: %v", device, err) return nil, err diff --git a/device/uapi.go b/device/uapi.go index 777bdda..1bce39b 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -98,32 +98,32 @@ func (device *Device) IpcGetOperation(w io.Writer) error { } if device.isAdvancedSecurityOn() { - if device.aSecCfg.junkPacketCount != 0 { - sendf("jc=%d", device.aSecCfg.junkPacketCount) + if device.awg.aSecCfg.junkPacketCount != 0 { + sendf("jc=%d", device.awg.aSecCfg.junkPacketCount) } - if device.aSecCfg.junkPacketMinSize != 0 { - sendf("jmin=%d", device.aSecCfg.junkPacketMinSize) + if device.awg.aSecCfg.junkPacketMinSize != 0 { + sendf("jmin=%d", device.awg.aSecCfg.junkPacketMinSize) } - if device.aSecCfg.junkPacketMaxSize != 0 { - sendf("jmax=%d", device.aSecCfg.junkPacketMaxSize) + if device.awg.aSecCfg.junkPacketMaxSize != 0 { + sendf("jmax=%d", device.awg.aSecCfg.junkPacketMaxSize) } - if device.aSecCfg.initPacketJunkSize != 0 { - sendf("s1=%d", device.aSecCfg.initPacketJunkSize) + if device.awg.aSecCfg.initPacketJunkSize != 0 { + sendf("s1=%d", device.awg.aSecCfg.initPacketJunkSize) } - if device.aSecCfg.responsePacketJunkSize != 0 { - sendf("s2=%d", device.aSecCfg.responsePacketJunkSize) + if device.awg.aSecCfg.responsePacketJunkSize != 0 { + sendf("s2=%d", device.awg.aSecCfg.responsePacketJunkSize) } - if device.aSecCfg.initPacketMagicHeader != 0 { - sendf("h1=%d", device.aSecCfg.initPacketMagicHeader) + if device.awg.aSecCfg.initPacketMagicHeader != 0 { + sendf("h1=%d", device.awg.aSecCfg.initPacketMagicHeader) } - if device.aSecCfg.responsePacketMagicHeader != 0 { - sendf("h2=%d", device.aSecCfg.responsePacketMagicHeader) + if device.awg.aSecCfg.responsePacketMagicHeader != 0 { + sendf("h2=%d", device.awg.aSecCfg.responsePacketMagicHeader) } - if device.aSecCfg.underloadPacketMagicHeader != 0 { - sendf("h3=%d", device.aSecCfg.underloadPacketMagicHeader) + if device.awg.aSecCfg.underloadPacketMagicHeader != 0 { + sendf("h3=%d", device.awg.aSecCfg.underloadPacketMagicHeader) } - if device.aSecCfg.transportPacketMagicHeader != 0 { - sendf("h4=%d", device.aSecCfg.transportPacketMagicHeader) + if device.awg.aSecCfg.transportPacketMagicHeader != 0 { + sendf("h4=%d", device.awg.aSecCfg.transportPacketMagicHeader) } }