From 0dcd528694ad8c9babd807cb38bce2aea64d3126 Mon Sep 17 00:00:00 2001 From: Mark Puha Date: Mon, 10 Feb 2025 10:05:42 +0100 Subject: [PATCH] rename vars and complete testing --- adapter/lua.go | 55 +++++++++++++++++++++++++++------------- adapter/lua_test.go | 22 ++++++++++++++++ device/device.go | 26 +++++++++++-------- device/device_test.go | 15 +++++++++-- device/noise-protocol.go | 20 +++++++-------- device/receive.go | 20 +++++---------- device/send.go | 25 +++++++++--------- device/uapi.go | 6 ++--- 8 files changed, 121 insertions(+), 68 deletions(-) diff --git a/adapter/lua.go b/adapter/lua.go index 371878d..4677345 100644 --- a/adapter/lua.go +++ b/adapter/lua.go @@ -8,9 +8,9 @@ import ( "github.com/aarzilli/golua/lua" ) -// TODO: aSec sync is enough? type Lua struct { - state *lua.State + generateState *lua.State + parseState *lua.State packetCounter atomic.Int64 base64LuaCode string } @@ -24,51 +24,72 @@ func NewLua(params LuaParams) (*Lua, error) { if err != nil { return nil, err } - fmt.Println(string(luaCode)) + strLuaCode := string(luaCode) + + generateState, err := initState(strLuaCode) + if err != nil { + return nil, err + } + parseState, err := initState(strLuaCode) + if err != nil { + return nil, err + } + + return &Lua{ + generateState: generateState, + parseState: parseState, + base64LuaCode: params.Base64LuaCode, + }, nil +} + +func initState(luaCode string) (*lua.State, error) { state := lua.NewState() state.OpenLibs() if err := state.DoString(string(luaCode)); err != nil { return nil, fmt.Errorf("Error loading Lua code: %v\n", err) } - return &Lua{state: state, base64LuaCode: params.Base64LuaCode}, nil + return state, nil } func (l *Lua) Close() { - l.state.Close() + l.generateState.Close() + l.parseState.Close() } +// Only thread safe if used by wg packet creation which happens independably func (l *Lua) Generate( msgType int64, data []byte, ) ([]byte, error) { - l.state.GetGlobal("d_gen") + l.generateState.GetGlobal("d_gen") - l.state.PushInteger(msgType) - l.state.PushBytes(data) - l.state.PushInteger(l.packetCounter.Add(1)) + l.generateState.PushInteger(msgType) + l.generateState.PushBytes(data) + l.generateState.PushInteger(l.packetCounter.Add(1)) - if err := l.state.Call(3, 1); err != nil { + if err := l.generateState.Call(3, 1); err != nil { return nil, fmt.Errorf("Error calling Lua function: %v\n", err) } - result := l.state.ToBytes(-1) - l.state.Pop(1) + result := l.generateState.ToBytes(-1) + l.generateState.Pop(1) return result, nil } +// Only thread safe if used by wg packet receive which happens independably func (l *Lua) Parse(data []byte) ([]byte, error) { - l.state.GetGlobal("d_parse") + l.parseState.GetGlobal("d_parse") - l.state.PushBytes(data) - if err := l.state.Call(1, 1); err != nil { + l.parseState.PushBytes(data) + if err := l.parseState.Call(1, 1); err != nil { return nil, fmt.Errorf("Error calling Lua function: %v\n", err) } - result := l.state.ToBytes(-1) - l.state.Pop(1) + result := l.parseState.ToBytes(-1) + l.parseState.Pop(1) return result, nil } diff --git a/adapter/lua_test.go b/adapter/lua_test.go index 1a3d039..88389f5 100644 --- a/adapter/lua_test.go +++ b/adapter/lua_test.go @@ -58,3 +58,25 @@ func TestLua_Parse(t *testing.T) { } }) } + + +var R []byte +var l = newLua() +func BenchmarkLuaGenerate(b *testing.B) { + for i := 0; i < b.N; i++ { var err error + R, err = l.Generate(1, []byte("test")) + if err != nil { + return + } + } +} + +func BenchmarkLuaParse(b *testing.B) { + for i := 0; i < b.N; i++ { + var err error + R, err = l.Parse([]byte("1headertest")) + if err != nil { + return + } + } +} diff --git a/device/device.go b/device/device.go index b16b460..8a96ffd 100644 --- a/device/device.go +++ b/device/device.go @@ -98,11 +98,11 @@ type Device struct { type awgType struct { isASecOn abool.AtomicBool - aSecMux sync.RWMutex + mutex sync.RWMutex aSecCfg aSecCfgType junkCreator junkCreator - luaAdapter *adapter.Lua + codec *adapter.Lua } type aSecCfgType struct { @@ -434,9 +434,9 @@ func (device *Device) Close() { device.resetProtocol() - if device.awg.luaAdapter != nil { - device.awg.luaAdapter.Close() - device.awg.luaAdapter = nil + if device.awg.codec != nil { + device.awg.codec.Close() + device.awg.codec = nil } device.log.Verbosef("Device closed") close(device.closed) @@ -601,7 +601,7 @@ func (device *Device) handlePostConfig(tempAwgType *awgType) (err error) { } isASecOn := false - device.awg.aSecMux.Lock() + device.awg.mutex.Lock() if tempAwgType.aSecCfg.junkPacketCount < 0 { err = ipcErrorf( ipc.IpcErrorInvalid, @@ -828,16 +828,20 @@ func (device *Device) handlePostConfig(tempAwgType *awgType) (err error) { if device.awg.isASecOn.IsSet() { device.awg.junkCreator, err = NewJunkCreator(device) } - device.awg.luaAdapter = tempAwgType.luaAdapter - device.awg.aSecMux.Unlock() + device.awg.codec = tempAwgType.codec + device.awg.mutex.Unlock() return err } -func (device *Device) codecPacket(msgType uint32, packet []byte) ([]byte, error) { - if device.awg.luaAdapter != nil { +func (device *Device) isCodecActive() bool { + return device.awg.codec != nil +} + +func (device *Device) codecPacketIfActive(msgType uint32, packet []byte) ([]byte, error) { + if device.isCodecActive() { var err error - packet, err = device.awg.luaAdapter.Generate(int64(msgType),packet) + packet, err = device.awg.codec.Generate(int64(msgType),packet) if err != nil { device.log.Errorf("%v - Failed to run codec generate: %v", device, err) return nil, err diff --git a/device/device_test.go b/device/device_test.go index 367f097..0f42c46 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -103,11 +103,22 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { } pub1, pub2 := key1.publicKey(), key2.publicKey() + /* + head = "headheadhead" + tail = "tailtailtail" + function d_gen(msg_type, data, counter) + return head .. data .. tail + end + + function d_parse(data) + return string.match(data, head.. "(.-)".. tail) + end + */ cfgs[0] = uapiCfg( "private_key", hex.EncodeToString(key1[:]), "listen_port", "0", "replace_peers", "true", - "lua_codec", "CQkJZnVuY3Rpb24gZF9nZW4obXNnX3R5cGUsIGRhdGEsIGNvdW50ZXIpCgkJCQlsb2NhbCBoZWFkZXIgPSAiaGVhZGVyIgoJCQkJcmV0dXJuIGhlYWRlciAuLiBkYXRhCgkJCWVuZAoKCQkJZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJCQkJbG9jYWwgaGVhZGVyID0gImhlYWRlciIKCQkJCXJldHVybiBzdHJpbmcuc3ViKGRhdGEsICNoZWFkZXIrMSkKCQkJZW5kCg==", + "lua_codec", "aGVhZCA9ICJoZWFkaGVhZGhlYWQiCnRhaWwgPSAidGFpbHRhaWx0YWlsIgpmdW5jdGlvbiBkX2dlbihtc2dfdHlwZSwgZGF0YSwgY291bnRlcikKCXJldHVybiBoZWFkIC4uIGRhdGEgLi4gdGFpbAplbmQKCmZ1bmN0aW9uIGRfcGFyc2UoZGF0YSkKICAgICAgICByZXR1cm4gc3RyaW5nLm1hdGNoKGRhdGEsIGhlYWQuLiAiKC4tKSIuLiB0YWlsKQplbmQK", "jc", "5", "jmin", "500", "jmax", "1000", @@ -130,7 +141,7 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { "private_key", hex.EncodeToString(key2[:]), "listen_port", "0", "replace_peers", "true", - "lua_codec", "CQkJZnVuY3Rpb24gZF9nZW4obXNnX3R5cGUsIGRhdGEsIGNvdW50ZXIpCgkJCQlsb2NhbCBoZWFkZXIgPSAiaGVhZGVyIgoJCQkJcmV0dXJuIGhlYWRlciAuLiBkYXRhCgkJCWVuZAoKCQkJZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJCQkJbG9jYWwgaGVhZGVyID0gImhlYWRlciIKCQkJCXJldHVybiBzdHJpbmcuc3ViKGRhdGEsICNoZWFkZXIrMSkKCQkJZW5kCg==", + "lua_codec", "aGVhZCA9ICJoZWFkaGVhZGhlYWQiCnRhaWwgPSAidGFpbHRhaWx0YWlsIgpmdW5jdGlvbiBkX2dlbihtc2dfdHlwZSwgZGF0YSwgY291bnRlcikKCXJldHVybiBoZWFkIC4uIGRhdGEgLi4gdGFpbAplbmQKCmZ1bmN0aW9uIGRfcGFyc2UoZGF0YSkKICAgICAgICByZXR1cm4gc3RyaW5nLm1hdGNoKGRhdGEsIGhlYWQuLiAiKC4tKSIuLiB0YWlsKQplbmQK", "jc", "5", "jmin", "500", "jmax", "1000", diff --git a/device/noise-protocol.go b/device/noise-protocol.go index 2dc0d93..75352c0 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -204,12 +204,12 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e handshake.mixHash(handshake.remoteStatic[:]) - device.awg.aSecMux.RLock() + device.awg.mutex.RLock() msg := MessageInitiation{ Type: MessageInitiationType, Ephemeral: handshake.localEphemeral.publicKey(), } - device.awg.aSecMux.RUnlock() + device.awg.mutex.RUnlock() handshake.mixKey(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:]) @@ -263,12 +263,12 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { chainKey [blake2s.Size]byte ) - device.awg.aSecMux.RLock() + device.awg.mutex.RLock() if msg.Type != MessageInitiationType { - device.awg.aSecMux.RUnlock() + device.awg.mutex.RUnlock() return nil } - device.awg.aSecMux.RUnlock() + device.awg.mutex.RUnlock() device.staticIdentity.RLock() defer device.staticIdentity.RUnlock() @@ -383,9 +383,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error } var msg MessageResponse - device.awg.aSecMux.RLock() + device.awg.mutex.RLock() msg.Type = MessageResponseType - device.awg.aSecMux.RUnlock() + device.awg.mutex.RUnlock() msg.Sender = handshake.localIndex msg.Receiver = handshake.remoteIndex @@ -435,12 +435,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error } func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { - device.awg.aSecMux.RLock() + device.awg.mutex.RLock() if msg.Type != MessageResponseType { - device.awg.aSecMux.RUnlock() + device.awg.mutex.RUnlock() return nil } - device.awg.aSecMux.RUnlock() + device.awg.mutex.RUnlock() // lookup handshake by receiver diff --git a/device/receive.go b/device/receive.go index 4ac5783..3df269e 100644 --- a/device/receive.go +++ b/device/receive.go @@ -12,7 +12,6 @@ import ( "net" "sync" "time" - "unsafe" "github.com/amnezia-vpn/amneziawg-go/conn" "golang.org/x/crypto/chacha20poly1305" @@ -130,7 +129,7 @@ func (device *Device) RoutineReceiveIncoming( } deathSpiral = 0 - device.awg.aSecMux.RLock() + device.awg.mutex.RLock() // handle each packet in the batch for i, size := range sizes[:count] { if size < MinMessageSize { @@ -138,15 +137,10 @@ func (device *Device) RoutineReceiveIncoming( } packet := bufsArrs[i][:size] - if device.awg.luaAdapter != nil { - realPacket, err := device.awg.luaAdapter.Parse(packet) - - packetPtr := (*byte)(unsafe.Pointer(bufsArrs[i])) // Get pointer to the array - // Copy data from realPacket to the memory pointed to by bufsArrs[i] + if device.isCodecActive() { + realPacket, err := device.awg.codec.Parse(packet) + copy(packet, realPacket) size = len(realPacket) - for j := 0; j < size; j++ { - *(*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(packetPtr)) + uintptr(j))) = realPacket[j] - } packet = bufsArrs[i][:size] if err != nil { device.log.Verbosef( @@ -262,7 +256,7 @@ func (device *Device) RoutineReceiveIncoming( default: } } - device.awg.aSecMux.RUnlock() + device.awg.mutex.RUnlock() for peer, elemsContainer := range elemsByPeer { if peer.isRunning.Load() { peer.queue.inbound.c <- elemsContainer @@ -322,7 +316,7 @@ func (device *Device) RoutineHandshake(id int) { for elem := range device.queue.handshake.c { - device.awg.aSecMux.RLock() + device.awg.mutex.RLock() // handle cookie fields and ratelimiting @@ -474,7 +468,7 @@ func (device *Device) RoutineHandshake(id int) { peer.SendKeepalive() } skip: - device.awg.aSecMux.RUnlock() + device.awg.mutex.RUnlock() device.PutMessageBuffer(elem.buffer) } } diff --git a/device/send.go b/device/send.go index 5acd81c..d6352e9 100644 --- a/device/send.go +++ b/device/send.go @@ -127,13 +127,14 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { ) return err } + var sendBuffer [][]byte // so only packet processed for cookie generation var junkedHeader []byte if peer.device.isAdvancedSecurityOn() { - peer.device.awg.aSecMux.RLock() + peer.device.awg.mutex.RLock() junks, err := peer.device.awg.junkCreator.createJunkPackets(peer) - peer.device.awg.aSecMux.RUnlock() + peer.device.awg.mutex.RUnlock() if err != nil { peer.device.log.Errorf("%v - %v", peer, err) @@ -153,19 +154,19 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { } } - peer.device.awg.aSecMux.RLock() + peer.device.awg.mutex.RLock() if peer.device.awg.aSecCfg.initPacketJunkSize != 0 { buf := make([]byte, 0, peer.device.awg.aSecCfg.initPacketJunkSize) writer := bytes.NewBuffer(buf[:0]) err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.initPacketJunkSize) if err != nil { peer.device.log.Errorf("%v - %v", peer, err) - peer.device.awg.aSecMux.RUnlock() + peer.device.awg.mutex.RUnlock() return err } junkedHeader = writer.Bytes() } - peer.device.awg.aSecMux.RUnlock() + peer.device.awg.mutex.RUnlock() } var buf [MessageInitiationSize]byte @@ -175,7 +176,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { peer.cookieGenerator.AddMacs(packet) junkedHeader = append(junkedHeader, packet...) - if junkedHeader, err = peer.device.codecPacket(DefaultMessageInitiationType, junkedHeader); err != nil { + if junkedHeader, err = peer.device.codecPacketIfActive(DefaultMessageInitiationType, junkedHeader); err != nil { return err } @@ -215,19 +216,19 @@ func (peer *Peer) SendHandshakeResponse() error { } var junkedHeader []byte if peer.device.isAdvancedSecurityOn() { - peer.device.awg.aSecMux.RLock() + peer.device.awg.mutex.RLock() if peer.device.awg.aSecCfg.responsePacketJunkSize != 0 { buf := make([]byte, 0, peer.device.awg.aSecCfg.responsePacketJunkSize) writer := bytes.NewBuffer(buf[:0]) err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.responsePacketJunkSize) if err != nil { - peer.device.awg.aSecMux.RUnlock() + peer.device.awg.mutex.RUnlock() peer.device.log.Errorf("%v - %v", peer, err) return err } junkedHeader = writer.Bytes() } - peer.device.awg.aSecMux.RUnlock() + peer.device.awg.mutex.RUnlock() } var buf [MessageResponseSize]byte writer := bytes.NewBuffer(buf[:0]) @@ -237,7 +238,7 @@ func (peer *Peer) SendHandshakeResponse() error { peer.cookieGenerator.AddMacs(packet) junkedHeader = append(junkedHeader, packet...) - if junkedHeader, err = peer.device.codecPacket(DefaultMessageResponseType, junkedHeader); err != nil { + if junkedHeader, err = peer.device.codecPacketIfActive(DefaultMessageResponseType, junkedHeader); err != nil { return err } @@ -286,7 +287,7 @@ func (device *Device) SendHandshakeCookie( writer := bytes.NewBuffer(buf[:0]) binary.Write(writer, binary.LittleEndian, reply) packet := writer.Bytes() - if packet, err = device.codecPacket(DefaultMessageCookieReplyType, packet); err != nil { + if packet, err = device.codecPacketIfActive(DefaultMessageCookieReplyType, packet); err != nil { return err } @@ -592,7 +593,7 @@ func (device *Device) RoutineEncryption(id int) { nil, ) var err error - if elem.packet, err = device.codecPacket(DefaultMessageTransportType, elem.packet); err != nil { + if elem.packet, err = device.codecPacketIfActive(DefaultMessageTransportType, elem.packet); err != nil { continue } } diff --git a/device/uapi.go b/device/uapi.go index a2bb21b..850cc6c 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -98,8 +98,8 @@ func (device *Device) IpcGetOperation(w io.Writer) error { sendf("fwmark=%d", device.net.fwmark) } - if device.awg.luaAdapter != nil { - sendf("lua_codec=%s", device.awg.luaAdapter.Base64LuaCode()) + if device.awg.codec != nil { + sendf("lua_codec=%s", device.awg.codec.Base64LuaCode()) } if device.isAdvancedSecurityOn() { if device.awg.aSecCfg.junkPacketCount != 0 { @@ -290,7 +290,7 @@ func (device *Device) handleDeviceLine(key, value string, tempAwgTpe *awgType) e case "lua_codec": device.log.Verbosef("UAPI: Updating lua_codec") var err error - tempAwgTpe.luaAdapter, err = adapter.NewLua(adapter.LuaParams{ + tempAwgTpe.codec, err = adapter.NewLua(adapter.LuaParams{ Base64LuaCode: value, }) if err != nil {