migrate awg specific code to struct

This commit is contained in:
Mark Puha 2025-02-09 09:36:53 +01:00
parent e08105f85b
commit 6b0bbcc75c
7 changed files with 95 additions and 90 deletions

View file

@ -92,14 +92,17 @@ type Device struct {
ipcMutex sync.RWMutex
closed chan struct{}
log *Logger
awg awgType
}
type awgType struct {
isASecOn abool.AtomicBool
aSecMux sync.RWMutex
aSecCfg aSecCfgType
junkCreator junkCreator
luaAdapter *adapter.Lua
packetCounter atomic.Int64
}
type aSecCfgType struct {
@ -431,9 +434,9 @@ func (device *Device) Close() {
device.resetProtocol()
if device.luaAdapter != nil {
device.luaAdapter.Close()
device.luaAdapter = nil
if device.awg.luaAdapter != nil {
device.awg.luaAdapter.Close()
device.awg.luaAdapter = nil
}
device.log.Verbosef("Device closed")
close(device.closed)
@ -581,7 +584,7 @@ func (device *Device) BindClose() error {
return err
}
func (device *Device) isAdvancedSecurityOn() bool {
return device.isASecOn.IsSet()
return device.awg.isASecOn.IsSet()
}
func (device *Device) resetProtocol() {
@ -594,36 +597,36 @@ func (device *Device) resetProtocol() {
func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
if !tempASecCfg.isSet {
return err
return nil
}
isASecOn := false
device.aSecMux.Lock()
device.awg.aSecMux.Lock()
if tempASecCfg.junkPacketCount < 0 {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketCount should be non negative",
)
}
device.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount
device.awg.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount
if tempASecCfg.junkPacketCount != 0 {
isASecOn = true
}
device.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize
device.awg.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize
if tempASecCfg.junkPacketMinSize != 0 {
isASecOn = true
}
if device.aSecCfg.junkPacketCount > 0 &&
if device.awg.aSecCfg.junkPacketCount > 0 &&
tempASecCfg.junkPacketMaxSize == tempASecCfg.junkPacketMinSize {
tempASecCfg.junkPacketMaxSize++ // to make rand gen work
}
if tempASecCfg.junkPacketMaxSize >= MaxSegmentSize {
device.aSecCfg.junkPacketMinSize = 0
device.aSecCfg.junkPacketMaxSize = 1
device.awg.aSecCfg.junkPacketMinSize = 0
device.awg.aSecCfg.junkPacketMaxSize = 1
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
@ -658,7 +661,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
)
}
} else {
device.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize
device.awg.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize
}
if tempASecCfg.junkPacketMaxSize != 0 {
@ -683,7 +686,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
)
}
} else {
device.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize
device.awg.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize
}
if tempASecCfg.initPacketJunkSize != 0 {
@ -708,7 +711,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
)
}
} else {
device.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize
device.awg.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize
}
if tempASecCfg.responsePacketJunkSize != 0 {
@ -718,8 +721,8 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
if tempASecCfg.initPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating init_packet_magic_header")
device.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader
MessageInitiationType = device.aSecCfg.initPacketMagicHeader
device.awg.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader
MessageInitiationType = device.awg.aSecCfg.initPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default init type")
MessageInitiationType = DefaultMessageInitiationType
@ -728,8 +731,8 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
if tempASecCfg.responsePacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
device.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader
MessageResponseType = device.aSecCfg.responsePacketMagicHeader
device.awg.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader
MessageResponseType = device.awg.aSecCfg.responsePacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default response type")
MessageResponseType = DefaultMessageResponseType
@ -738,8 +741,8 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
if tempASecCfg.underloadPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
device.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader
MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader
device.awg.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader
MessageCookieReplyType = device.awg.aSecCfg.underloadPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default underload type")
MessageCookieReplyType = DefaultMessageCookieReplyType
@ -748,8 +751,8 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
if tempASecCfg.transportPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
device.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader
MessageTransportType = device.aSecCfg.transportPacketMagicHeader
device.awg.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader
MessageTransportType = device.awg.aSecCfg.transportPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default transport type")
MessageTransportType = DefaultMessageTransportType
@ -785,8 +788,8 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
}
}
newInitSize := MessageInitiationSize + device.aSecCfg.initPacketJunkSize
newResponseSize := MessageResponseSize + device.aSecCfg.responsePacketJunkSize
newInitSize := MessageInitiationSize + device.awg.aSecCfg.initPacketJunkSize
newResponseSize := MessageResponseSize + device.awg.aSecCfg.responsePacketJunkSize
if newInitSize == newResponseSize {
if err != nil {
@ -814,16 +817,18 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
}
msgTypeToJunkSize = map[uint32]int{
MessageInitiationType: device.aSecCfg.initPacketJunkSize,
MessageResponseType: device.aSecCfg.responsePacketJunkSize,
MessageInitiationType: device.awg.aSecCfg.initPacketJunkSize,
MessageResponseType: device.awg.aSecCfg.responsePacketJunkSize,
MessageCookieReplyType: 0,
MessageTransportType: 0,
}
}
device.isASecOn.SetTo(isASecOn)
device.junkCreator, err = NewJunkCreator(device)
device.aSecMux.Unlock()
device.awg.isASecOn.SetTo(isASecOn)
if device.awg.isASecOn.IsSet() {
device.awg.junkCreator, err = NewJunkCreator(device)
}
device.awg.aSecMux.Unlock()
return err
}

View file

@ -23,12 +23,12 @@ func NewJunkCreator(d *Device) (junkCreator, error) {
// Should be called with aSecMux RLocked
func (jc *junkCreator) createJunkPackets(peer *Peer) ([][]byte, error) {
if jc.device.aSecCfg.junkPacketCount == 0 {
if jc.device.awg.aSecCfg.junkPacketCount == 0 {
return nil, nil
}
junks := make([][]byte, 0, jc.device.aSecCfg.junkPacketCount)
for i := 0; i < jc.device.aSecCfg.junkPacketCount; i++ {
junks := make([][]byte, 0, jc.device.awg.aSecCfg.junkPacketCount)
for i := 0; i < jc.device.awg.aSecCfg.junkPacketCount; i++ {
packetSize := jc.randomPacketSize()
junk, err := jc.randomJunkWithSize(packetSize)
if err != nil {
@ -48,9 +48,9 @@ func (jc *junkCreator) createJunkPackets(peer *Peer) ([][]byte, error) {
func (jc *junkCreator) randomPacketSize() int {
return int(
jc.cha8Rand.Uint64()%uint64(
jc.device.aSecCfg.junkPacketMaxSize-jc.device.aSecCfg.junkPacketMinSize,
jc.device.awg.aSecCfg.junkPacketMaxSize-jc.device.awg.aSecCfg.junkPacketMinSize,
),
) + jc.device.aSecCfg.junkPacketMinSize
) + jc.device.awg.aSecCfg.junkPacketMinSize
}
// Should be called with aSecMux RLocked

View file

@ -91,13 +91,13 @@ func Test_junkCreator_randomPacketSize(t *testing.T) {
}
for range [30]struct{}{} {
t.Run("", func(t *testing.T) {
if got := jc.randomPacketSize(); jc.device.aSecCfg.junkPacketMinSize > got ||
got > jc.device.aSecCfg.junkPacketMaxSize {
if got := jc.randomPacketSize(); jc.device.awg.aSecCfg.junkPacketMinSize > got ||
got > jc.device.awg.aSecCfg.junkPacketMaxSize {
t.Errorf(
"junkCreator.randomPacketSize() = %v, not between range [%v,%v]",
got,
jc.device.aSecCfg.junkPacketMinSize,
jc.device.aSecCfg.junkPacketMaxSize,
jc.device.awg.aSecCfg.junkPacketMinSize,
jc.device.awg.aSecCfg.junkPacketMaxSize,
)
}
})

View file

@ -204,12 +204,12 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(handshake.remoteStatic[:])
device.aSecMux.RLock()
device.awg.aSecMux.RLock()
msg := MessageInitiation{
Type: MessageInitiationType,
Ephemeral: handshake.localEphemeral.publicKey(),
}
device.aSecMux.RUnlock()
device.awg.aSecMux.RUnlock()
handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:])
@ -263,12 +263,12 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
chainKey [blake2s.Size]byte
)
device.aSecMux.RLock()
device.awg.aSecMux.RLock()
if msg.Type != MessageInitiationType {
device.aSecMux.RUnlock()
device.awg.aSecMux.RUnlock()
return nil
}
device.aSecMux.RUnlock()
device.awg.aSecMux.RUnlock()
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
@ -383,9 +383,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
}
var msg MessageResponse
device.aSecMux.RLock()
device.awg.aSecMux.RLock()
msg.Type = MessageResponseType
device.aSecMux.RUnlock()
device.awg.aSecMux.RUnlock()
msg.Sender = handshake.localIndex
msg.Receiver = handshake.remoteIndex
@ -435,12 +435,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
}
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
device.aSecMux.RLock()
device.awg.aSecMux.RLock()
if msg.Type != MessageResponseType {
device.aSecMux.RUnlock()
device.awg.aSecMux.RUnlock()
return nil
}
device.aSecMux.RUnlock()
device.awg.aSecMux.RUnlock()
// lookup handshake by receiver

View file

@ -129,7 +129,7 @@ func (device *Device) RoutineReceiveIncoming(
}
deathSpiral = 0
device.aSecMux.RLock()
device.awg.aSecMux.RLock()
// handle each packet in the batch
for i, size := range sizes[:count] {
if size < MinMessageSize {
@ -138,8 +138,8 @@ func (device *Device) RoutineReceiveIncoming(
// check size of packet
packet := bufsArrs[i][:size]
if device.luaAdapter != nil {
packet, err = device.luaAdapter.Parse(packet)
if device.awg.luaAdapter != nil {
packet, err = device.awg.luaAdapter.Parse(packet)
if err != nil {
device.log.Verbosef("Couldn't parse message; reason: %v", err)
continue
@ -251,7 +251,7 @@ func (device *Device) RoutineReceiveIncoming(
default:
}
}
device.aSecMux.RUnlock()
device.awg.aSecMux.RUnlock()
for peer, elemsContainer := range elemsByPeer {
if peer.isRunning.Load() {
peer.queue.inbound.c <- elemsContainer
@ -310,7 +310,7 @@ func (device *Device) RoutineHandshake(id int) {
for elem := range device.queue.handshake.c {
device.aSecMux.RLock()
device.awg.aSecMux.RLock()
// handle cookie fields and ratelimiting
@ -462,7 +462,7 @@ func (device *Device) RoutineHandshake(id int) {
peer.SendKeepalive()
}
skip:
device.aSecMux.RUnlock()
device.awg.aSecMux.RUnlock()
device.PutMessageBuffer(elem.buffer)
}
}

View file

@ -131,9 +131,9 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
// so only packet processed for cookie generation
var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() {
peer.device.aSecMux.RLock()
junks, err := peer.device.junkCreator.createJunkPackets(peer)
peer.device.aSecMux.RUnlock()
peer.device.awg.aSecMux.RLock()
junks, err := peer.device.awg.junkCreator.createJunkPackets(peer)
peer.device.awg.aSecMux.RUnlock()
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
@ -153,19 +153,19 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
}
}
peer.device.aSecMux.RLock()
if peer.device.aSecCfg.initPacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize)
peer.device.awg.aSecMux.RLock()
if peer.device.awg.aSecCfg.initPacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.awg.aSecCfg.initPacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize)
err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.initPacketJunkSize)
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
peer.device.aSecMux.RUnlock()
peer.device.awg.aSecMux.RUnlock()
return err
}
junkedHeader = writer.Bytes()
}
peer.device.aSecMux.RUnlock()
peer.device.awg.aSecMux.RUnlock()
}
var buf [MessageInitiationSize]byte
@ -215,19 +215,19 @@ func (peer *Peer) SendHandshakeResponse() error {
}
var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() {
peer.device.aSecMux.RLock()
if peer.device.aSecCfg.responsePacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize)
peer.device.awg.aSecMux.RLock()
if peer.device.awg.aSecCfg.responsePacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.awg.aSecCfg.responsePacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize)
err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.responsePacketJunkSize)
if err != nil {
peer.device.aSecMux.RUnlock()
peer.device.awg.aSecMux.RUnlock()
peer.device.log.Errorf("%v - %v", peer, err)
return err
}
junkedHeader = writer.Bytes()
}
peer.device.aSecMux.RUnlock()
peer.device.awg.aSecMux.RUnlock()
}
var buf [MessageResponseSize]byte
writer := bytes.NewBuffer(buf[:0])
@ -548,9 +548,9 @@ func calculatePaddingSize(packetSize, mtu int) int {
}
func (device *Device) codecPacket(msgType uint32, packet []byte) ([]byte, error) {
if device.luaAdapter != nil {
if device.awg.luaAdapter != nil {
var err error
packet, err = device.luaAdapter.Generate(int64(msgType),packet, device.packetCounter.Add(1))
packet, err = device.awg.luaAdapter.Generate(int64(msgType),packet)
if err != nil {
device.log.Errorf("%v - Failed to run codec generate: %v", device, err)
return nil, err

View file

@ -98,32 +98,32 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
}
if device.isAdvancedSecurityOn() {
if device.aSecCfg.junkPacketCount != 0 {
sendf("jc=%d", device.aSecCfg.junkPacketCount)
if device.awg.aSecCfg.junkPacketCount != 0 {
sendf("jc=%d", device.awg.aSecCfg.junkPacketCount)
}
if device.aSecCfg.junkPacketMinSize != 0 {
sendf("jmin=%d", device.aSecCfg.junkPacketMinSize)
if device.awg.aSecCfg.junkPacketMinSize != 0 {
sendf("jmin=%d", device.awg.aSecCfg.junkPacketMinSize)
}
if device.aSecCfg.junkPacketMaxSize != 0 {
sendf("jmax=%d", device.aSecCfg.junkPacketMaxSize)
if device.awg.aSecCfg.junkPacketMaxSize != 0 {
sendf("jmax=%d", device.awg.aSecCfg.junkPacketMaxSize)
}
if device.aSecCfg.initPacketJunkSize != 0 {
sendf("s1=%d", device.aSecCfg.initPacketJunkSize)
if device.awg.aSecCfg.initPacketJunkSize != 0 {
sendf("s1=%d", device.awg.aSecCfg.initPacketJunkSize)
}
if device.aSecCfg.responsePacketJunkSize != 0 {
sendf("s2=%d", device.aSecCfg.responsePacketJunkSize)
if device.awg.aSecCfg.responsePacketJunkSize != 0 {
sendf("s2=%d", device.awg.aSecCfg.responsePacketJunkSize)
}
if device.aSecCfg.initPacketMagicHeader != 0 {
sendf("h1=%d", device.aSecCfg.initPacketMagicHeader)
if device.awg.aSecCfg.initPacketMagicHeader != 0 {
sendf("h1=%d", device.awg.aSecCfg.initPacketMagicHeader)
}
if device.aSecCfg.responsePacketMagicHeader != 0 {
sendf("h2=%d", device.aSecCfg.responsePacketMagicHeader)
if device.awg.aSecCfg.responsePacketMagicHeader != 0 {
sendf("h2=%d", device.awg.aSecCfg.responsePacketMagicHeader)
}
if device.aSecCfg.underloadPacketMagicHeader != 0 {
sendf("h3=%d", device.aSecCfg.underloadPacketMagicHeader)
if device.awg.aSecCfg.underloadPacketMagicHeader != 0 {
sendf("h3=%d", device.awg.aSecCfg.underloadPacketMagicHeader)
}
if device.aSecCfg.transportPacketMagicHeader != 0 {
sendf("h4=%d", device.aSecCfg.transportPacketMagicHeader)
if device.awg.aSecCfg.transportPacketMagicHeader != 0 {
sendf("h4=%d", device.awg.aSecCfg.transportPacketMagicHeader)
}
}