From 0dcd528694ad8c9babd807cb38bce2aea64d3126 Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Mon, 10 Feb 2025 10:05:42 +0100
Subject: [PATCH] rename vars and complete testing
---
adapter/lua.go | 55 +++++++++++++++++++++++++++-------------
adapter/lua_test.go | 22 ++++++++++++++++
device/device.go | 26 +++++++++++--------
device/device_test.go | 15 +++++++++--
device/noise-protocol.go | 20 +++++++--------
device/receive.go | 20 +++++----------
device/send.go | 25 +++++++++---------
device/uapi.go | 6 ++---
8 files changed, 121 insertions(+), 68 deletions(-)
diff --git a/adapter/lua.go b/adapter/lua.go
index 371878d..4677345 100644
--- a/adapter/lua.go
+++ b/adapter/lua.go
@@ -8,9 +8,9 @@ import (
"github.com/aarzilli/golua/lua"
)
-// TODO: aSec sync is enough?
type Lua struct {
- state *lua.State
+ generateState *lua.State
+ parseState *lua.State
packetCounter atomic.Int64
base64LuaCode string
}
@@ -24,51 +24,72 @@ func NewLua(params LuaParams) (*Lua, error) {
if err != nil {
return nil, err
}
- fmt.Println(string(luaCode))
+ strLuaCode := string(luaCode)
+
+ generateState, err := initState(strLuaCode)
+ if err != nil {
+ return nil, err
+ }
+ parseState, err := initState(strLuaCode)
+ if err != nil {
+ return nil, err
+ }
+
+ return &Lua{
+ generateState: generateState,
+ parseState: parseState,
+ base64LuaCode: params.Base64LuaCode,
+ }, nil
+}
+
+func initState(luaCode string) (*lua.State, error) {
state := lua.NewState()
state.OpenLibs()
if err := state.DoString(string(luaCode)); err != nil {
return nil, fmt.Errorf("Error loading Lua code: %v\n", err)
}
- return &Lua{state: state, base64LuaCode: params.Base64LuaCode}, nil
+ return state, nil
}
func (l *Lua) Close() {
- l.state.Close()
+ l.generateState.Close()
+ l.parseState.Close()
}
+// Only thread safe if used by wg packet creation which happens independably
func (l *Lua) Generate(
msgType int64,
data []byte,
) ([]byte, error) {
- l.state.GetGlobal("d_gen")
+ l.generateState.GetGlobal("d_gen")
- l.state.PushInteger(msgType)
- l.state.PushBytes(data)
- l.state.PushInteger(l.packetCounter.Add(1))
+ l.generateState.PushInteger(msgType)
+ l.generateState.PushBytes(data)
+ l.generateState.PushInteger(l.packetCounter.Add(1))
- if err := l.state.Call(3, 1); err != nil {
+ if err := l.generateState.Call(3, 1); err != nil {
return nil, fmt.Errorf("Error calling Lua function: %v\n", err)
}
- result := l.state.ToBytes(-1)
- l.state.Pop(1)
+ result := l.generateState.ToBytes(-1)
+ l.generateState.Pop(1)
return result, nil
}
+// Only thread safe if used by wg packet receive which happens independably
func (l *Lua) Parse(data []byte) ([]byte, error) {
- l.state.GetGlobal("d_parse")
+ l.parseState.GetGlobal("d_parse")
- l.state.PushBytes(data)
- if err := l.state.Call(1, 1); err != nil {
+ l.parseState.PushBytes(data)
+ if err := l.parseState.Call(1, 1); err != nil {
return nil, fmt.Errorf("Error calling Lua function: %v\n", err)
}
- result := l.state.ToBytes(-1)
- l.state.Pop(1)
+ result := l.parseState.ToBytes(-1)
+ l.parseState.Pop(1)
return result, nil
}
diff --git a/adapter/lua_test.go b/adapter/lua_test.go
index 1a3d039..88389f5 100644
--- a/adapter/lua_test.go
+++ b/adapter/lua_test.go
@@ -58,3 +58,25 @@ func TestLua_Parse(t *testing.T) {
}
})
}
+
+
+var R []byte
+var l = newLua()
+func BenchmarkLuaGenerate(b *testing.B) {
+ for i := 0; i < b.N; i++ { var err error
+ R, err = l.Generate(1, []byte("test"))
+ if err != nil {
+ return
+ }
+ }
+}
+
+func BenchmarkLuaParse(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ var err error
+ R, err = l.Parse([]byte("1headertest"))
+ if err != nil {
+ return
+ }
+ }
+}
diff --git a/device/device.go b/device/device.go
index b16b460..8a96ffd 100644
--- a/device/device.go
+++ b/device/device.go
@@ -98,11 +98,11 @@ type Device struct {
type awgType struct {
isASecOn abool.AtomicBool
- aSecMux sync.RWMutex
+ mutex sync.RWMutex
aSecCfg aSecCfgType
junkCreator junkCreator
- luaAdapter *adapter.Lua
+ codec *adapter.Lua
}
type aSecCfgType struct {
@@ -434,9 +434,9 @@ func (device *Device) Close() {
device.resetProtocol()
- if device.awg.luaAdapter != nil {
- device.awg.luaAdapter.Close()
- device.awg.luaAdapter = nil
+ if device.awg.codec != nil {
+ device.awg.codec.Close()
+ device.awg.codec = nil
}
device.log.Verbosef("Device closed")
close(device.closed)
@@ -601,7 +601,7 @@ func (device *Device) handlePostConfig(tempAwgType *awgType) (err error) {
}
isASecOn := false
- device.awg.aSecMux.Lock()
+ device.awg.mutex.Lock()
if tempAwgType.aSecCfg.junkPacketCount < 0 {
err = ipcErrorf(
ipc.IpcErrorInvalid,
@@ -828,16 +828,20 @@ func (device *Device) handlePostConfig(tempAwgType *awgType) (err error) {
if device.awg.isASecOn.IsSet() {
device.awg.junkCreator, err = NewJunkCreator(device)
}
- device.awg.luaAdapter = tempAwgType.luaAdapter
- device.awg.aSecMux.Unlock()
+ device.awg.codec = tempAwgType.codec
+ device.awg.mutex.Unlock()
return err
}
-func (device *Device) codecPacket(msgType uint32, packet []byte) ([]byte, error) {
- if device.awg.luaAdapter != nil {
+func (device *Device) isCodecActive() bool {
+ return device.awg.codec != nil
+}
+
+func (device *Device) codecPacketIfActive(msgType uint32, packet []byte) ([]byte, error) {
+ if device.isCodecActive() {
var err error
- packet, err = device.awg.luaAdapter.Generate(int64(msgType),packet)
+ packet, err = device.awg.codec.Generate(int64(msgType),packet)
if err != nil {
device.log.Errorf("%v - Failed to run codec generate: %v", device, err)
return nil, err
diff --git a/device/device_test.go b/device/device_test.go
index 367f097..0f42c46 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -103,11 +103,22 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
}
pub1, pub2 := key1.publicKey(), key2.publicKey()
+ /*
+ head = "headheadhead"
+ tail = "tailtailtail"
+ function d_gen(msg_type, data, counter)
+ return head .. data .. tail
+ end
+
+ function d_parse(data)
+ return string.match(data, head.. "(.-)".. tail)
+ end
+ */
cfgs[0] = uapiCfg(
"private_key", hex.EncodeToString(key1[:]),
"listen_port", "0",
"replace_peers", "true",
- "lua_codec", "CQkJZnVuY3Rpb24gZF9nZW4obXNnX3R5cGUsIGRhdGEsIGNvdW50ZXIpCgkJCQlsb2NhbCBoZWFkZXIgPSAiaGVhZGVyIgoJCQkJcmV0dXJuIGhlYWRlciAuLiBkYXRhCgkJCWVuZAoKCQkJZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJCQkJbG9jYWwgaGVhZGVyID0gImhlYWRlciIKCQkJCXJldHVybiBzdHJpbmcuc3ViKGRhdGEsICNoZWFkZXIrMSkKCQkJZW5kCg==",
+ "lua_codec", "aGVhZCA9ICJoZWFkaGVhZGhlYWQiCnRhaWwgPSAidGFpbHRhaWx0YWlsIgpmdW5jdGlvbiBkX2dlbihtc2dfdHlwZSwgZGF0YSwgY291bnRlcikKCXJldHVybiBoZWFkIC4uIGRhdGEgLi4gdGFpbAplbmQKCmZ1bmN0aW9uIGRfcGFyc2UoZGF0YSkKICAgICAgICByZXR1cm4gc3RyaW5nLm1hdGNoKGRhdGEsIGhlYWQuLiAiKC4tKSIuLiB0YWlsKQplbmQK",
"jc", "5",
"jmin", "500",
"jmax", "1000",
@@ -130,7 +141,7 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
"private_key", hex.EncodeToString(key2[:]),
"listen_port", "0",
"replace_peers", "true",
- "lua_codec", "CQkJZnVuY3Rpb24gZF9nZW4obXNnX3R5cGUsIGRhdGEsIGNvdW50ZXIpCgkJCQlsb2NhbCBoZWFkZXIgPSAiaGVhZGVyIgoJCQkJcmV0dXJuIGhlYWRlciAuLiBkYXRhCgkJCWVuZAoKCQkJZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJCQkJbG9jYWwgaGVhZGVyID0gImhlYWRlciIKCQkJCXJldHVybiBzdHJpbmcuc3ViKGRhdGEsICNoZWFkZXIrMSkKCQkJZW5kCg==",
+ "lua_codec", "aGVhZCA9ICJoZWFkaGVhZGhlYWQiCnRhaWwgPSAidGFpbHRhaWx0YWlsIgpmdW5jdGlvbiBkX2dlbihtc2dfdHlwZSwgZGF0YSwgY291bnRlcikKCXJldHVybiBoZWFkIC4uIGRhdGEgLi4gdGFpbAplbmQKCmZ1bmN0aW9uIGRfcGFyc2UoZGF0YSkKICAgICAgICByZXR1cm4gc3RyaW5nLm1hdGNoKGRhdGEsIGhlYWQuLiAiKC4tKSIuLiB0YWlsKQplbmQK",
"jc", "5",
"jmin", "500",
"jmax", "1000",
diff --git a/device/noise-protocol.go b/device/noise-protocol.go
index 2dc0d93..75352c0 100644
--- a/device/noise-protocol.go
+++ b/device/noise-protocol.go
@@ -204,12 +204,12 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(handshake.remoteStatic[:])
- device.awg.aSecMux.RLock()
+ device.awg.mutex.RLock()
msg := MessageInitiation{
Type: MessageInitiationType,
Ephemeral: handshake.localEphemeral.publicKey(),
}
- device.awg.aSecMux.RUnlock()
+ device.awg.mutex.RUnlock()
handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:])
@@ -263,12 +263,12 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
chainKey [blake2s.Size]byte
)
- device.awg.aSecMux.RLock()
+ device.awg.mutex.RLock()
if msg.Type != MessageInitiationType {
- device.awg.aSecMux.RUnlock()
+ device.awg.mutex.RUnlock()
return nil
}
- device.awg.aSecMux.RUnlock()
+ device.awg.mutex.RUnlock()
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
@@ -383,9 +383,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
}
var msg MessageResponse
- device.awg.aSecMux.RLock()
+ device.awg.mutex.RLock()
msg.Type = MessageResponseType
- device.awg.aSecMux.RUnlock()
+ device.awg.mutex.RUnlock()
msg.Sender = handshake.localIndex
msg.Receiver = handshake.remoteIndex
@@ -435,12 +435,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
}
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
- device.awg.aSecMux.RLock()
+ device.awg.mutex.RLock()
if msg.Type != MessageResponseType {
- device.awg.aSecMux.RUnlock()
+ device.awg.mutex.RUnlock()
return nil
}
- device.awg.aSecMux.RUnlock()
+ device.awg.mutex.RUnlock()
// lookup handshake by receiver
diff --git a/device/receive.go b/device/receive.go
index 4ac5783..3df269e 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -12,7 +12,6 @@ import (
"net"
"sync"
"time"
- "unsafe"
"github.com/amnezia-vpn/amneziawg-go/conn"
"golang.org/x/crypto/chacha20poly1305"
@@ -130,7 +129,7 @@ func (device *Device) RoutineReceiveIncoming(
}
deathSpiral = 0
- device.awg.aSecMux.RLock()
+ device.awg.mutex.RLock()
// handle each packet in the batch
for i, size := range sizes[:count] {
if size < MinMessageSize {
@@ -138,15 +137,10 @@ func (device *Device) RoutineReceiveIncoming(
}
packet := bufsArrs[i][:size]
- if device.awg.luaAdapter != nil {
- realPacket, err := device.awg.luaAdapter.Parse(packet)
-
- packetPtr := (*byte)(unsafe.Pointer(bufsArrs[i])) // Get pointer to the array
- // Copy data from realPacket to the memory pointed to by bufsArrs[i]
+ if device.isCodecActive() {
+ realPacket, err := device.awg.codec.Parse(packet)
+ copy(packet, realPacket)
size = len(realPacket)
- for j := 0; j < size; j++ {
- *(*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(packetPtr)) + uintptr(j))) = realPacket[j]
- }
packet = bufsArrs[i][:size]
if err != nil {
device.log.Verbosef(
@@ -262,7 +256,7 @@ func (device *Device) RoutineReceiveIncoming(
default:
}
}
- device.awg.aSecMux.RUnlock()
+ device.awg.mutex.RUnlock()
for peer, elemsContainer := range elemsByPeer {
if peer.isRunning.Load() {
peer.queue.inbound.c <- elemsContainer
@@ -322,7 +316,7 @@ func (device *Device) RoutineHandshake(id int) {
for elem := range device.queue.handshake.c {
- device.awg.aSecMux.RLock()
+ device.awg.mutex.RLock()
// handle cookie fields and ratelimiting
@@ -474,7 +468,7 @@ func (device *Device) RoutineHandshake(id int) {
peer.SendKeepalive()
}
skip:
- device.awg.aSecMux.RUnlock()
+ device.awg.mutex.RUnlock()
device.PutMessageBuffer(elem.buffer)
}
}
diff --git a/device/send.go b/device/send.go
index 5acd81c..d6352e9 100644
--- a/device/send.go
+++ b/device/send.go
@@ -127,13 +127,14 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
)
return err
}
+
var sendBuffer [][]byte
// so only packet processed for cookie generation
var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() {
- peer.device.awg.aSecMux.RLock()
+ peer.device.awg.mutex.RLock()
junks, err := peer.device.awg.junkCreator.createJunkPackets(peer)
- peer.device.awg.aSecMux.RUnlock()
+ peer.device.awg.mutex.RUnlock()
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
@@ -153,19 +154,19 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
}
}
- peer.device.awg.aSecMux.RLock()
+ peer.device.awg.mutex.RLock()
if peer.device.awg.aSecCfg.initPacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.awg.aSecCfg.initPacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.initPacketJunkSize)
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
- peer.device.awg.aSecMux.RUnlock()
+ peer.device.awg.mutex.RUnlock()
return err
}
junkedHeader = writer.Bytes()
}
- peer.device.awg.aSecMux.RUnlock()
+ peer.device.awg.mutex.RUnlock()
}
var buf [MessageInitiationSize]byte
@@ -175,7 +176,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
peer.cookieGenerator.AddMacs(packet)
junkedHeader = append(junkedHeader, packet...)
- if junkedHeader, err = peer.device.codecPacket(DefaultMessageInitiationType, junkedHeader); err != nil {
+ if junkedHeader, err = peer.device.codecPacketIfActive(DefaultMessageInitiationType, junkedHeader); err != nil {
return err
}
@@ -215,19 +216,19 @@ func (peer *Peer) SendHandshakeResponse() error {
}
var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() {
- peer.device.awg.aSecMux.RLock()
+ peer.device.awg.mutex.RLock()
if peer.device.awg.aSecCfg.responsePacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.awg.aSecCfg.responsePacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.responsePacketJunkSize)
if err != nil {
- peer.device.awg.aSecMux.RUnlock()
+ peer.device.awg.mutex.RUnlock()
peer.device.log.Errorf("%v - %v", peer, err)
return err
}
junkedHeader = writer.Bytes()
}
- peer.device.awg.aSecMux.RUnlock()
+ peer.device.awg.mutex.RUnlock()
}
var buf [MessageResponseSize]byte
writer := bytes.NewBuffer(buf[:0])
@@ -237,7 +238,7 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.cookieGenerator.AddMacs(packet)
junkedHeader = append(junkedHeader, packet...)
- if junkedHeader, err = peer.device.codecPacket(DefaultMessageResponseType, junkedHeader); err != nil {
+ if junkedHeader, err = peer.device.codecPacketIfActive(DefaultMessageResponseType, junkedHeader); err != nil {
return err
}
@@ -286,7 +287,7 @@ func (device *Device) SendHandshakeCookie(
writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, reply)
packet := writer.Bytes()
- if packet, err = device.codecPacket(DefaultMessageCookieReplyType, packet); err != nil {
+ if packet, err = device.codecPacketIfActive(DefaultMessageCookieReplyType, packet); err != nil {
return err
}
@@ -592,7 +593,7 @@ func (device *Device) RoutineEncryption(id int) {
nil,
)
var err error
- if elem.packet, err = device.codecPacket(DefaultMessageTransportType, elem.packet); err != nil {
+ if elem.packet, err = device.codecPacketIfActive(DefaultMessageTransportType, elem.packet); err != nil {
continue
}
}
diff --git a/device/uapi.go b/device/uapi.go
index a2bb21b..850cc6c 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -98,8 +98,8 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
sendf("fwmark=%d", device.net.fwmark)
}
- if device.awg.luaAdapter != nil {
- sendf("lua_codec=%s", device.awg.luaAdapter.Base64LuaCode())
+ if device.awg.codec != nil {
+ sendf("lua_codec=%s", device.awg.codec.Base64LuaCode())
}
if device.isAdvancedSecurityOn() {
if device.awg.aSecCfg.junkPacketCount != 0 {
@@ -290,7 +290,7 @@ func (device *Device) handleDeviceLine(key, value string, tempAwgTpe *awgType) e
case "lua_codec":
device.log.Verbosef("UAPI: Updating lua_codec")
var err error
- tempAwgTpe.luaAdapter, err = adapter.NewLua(adapter.LuaParams{
+ tempAwgTpe.codec, err = adapter.NewLua(adapter.LuaParams{
Base64LuaCode: value,
})
if err != nil {