rename vars and complete testing

This commit is contained in:
Mark Puha 2025-02-10 10:05:42 +01:00
parent 068e146538
commit 0dcd528694
8 changed files with 121 additions and 68 deletions

View file

@ -8,9 +8,9 @@ import (
"github.com/aarzilli/golua/lua"
)
// TODO: aSec sync is enough?
type Lua struct {
state *lua.State
generateState *lua.State
parseState *lua.State
packetCounter atomic.Int64
base64LuaCode string
}
@ -24,51 +24,72 @@ func NewLua(params LuaParams) (*Lua, error) {
if err != nil {
return nil, err
}
fmt.Println(string(luaCode))
strLuaCode := string(luaCode)
generateState, err := initState(strLuaCode)
if err != nil {
return nil, err
}
parseState, err := initState(strLuaCode)
if err != nil {
return nil, err
}
return &Lua{
generateState: generateState,
parseState: parseState,
base64LuaCode: params.Base64LuaCode,
}, nil
}
func initState(luaCode string) (*lua.State, error) {
state := lua.NewState()
state.OpenLibs()
if err := state.DoString(string(luaCode)); err != nil {
return nil, fmt.Errorf("Error loading Lua code: %v\n", err)
}
return &Lua{state: state, base64LuaCode: params.Base64LuaCode}, nil
return state, nil
}
func (l *Lua) Close() {
l.state.Close()
l.generateState.Close()
l.parseState.Close()
}
// Only thread safe if used by wg packet creation which happens independably
func (l *Lua) Generate(
msgType int64,
data []byte,
) ([]byte, error) {
l.state.GetGlobal("d_gen")
l.generateState.GetGlobal("d_gen")
l.state.PushInteger(msgType)
l.state.PushBytes(data)
l.state.PushInteger(l.packetCounter.Add(1))
l.generateState.PushInteger(msgType)
l.generateState.PushBytes(data)
l.generateState.PushInteger(l.packetCounter.Add(1))
if err := l.state.Call(3, 1); err != nil {
if err := l.generateState.Call(3, 1); err != nil {
return nil, fmt.Errorf("Error calling Lua function: %v\n", err)
}
result := l.state.ToBytes(-1)
l.state.Pop(1)
result := l.generateState.ToBytes(-1)
l.generateState.Pop(1)
return result, nil
}
// Only thread safe if used by wg packet receive which happens independably
func (l *Lua) Parse(data []byte) ([]byte, error) {
l.state.GetGlobal("d_parse")
l.parseState.GetGlobal("d_parse")
l.state.PushBytes(data)
if err := l.state.Call(1, 1); err != nil {
l.parseState.PushBytes(data)
if err := l.parseState.Call(1, 1); err != nil {
return nil, fmt.Errorf("Error calling Lua function: %v\n", err)
}
result := l.state.ToBytes(-1)
l.state.Pop(1)
result := l.parseState.ToBytes(-1)
l.parseState.Pop(1)
return result, nil
}

View file

@ -58,3 +58,25 @@ func TestLua_Parse(t *testing.T) {
}
})
}
var R []byte
var l = newLua()
func BenchmarkLuaGenerate(b *testing.B) {
for i := 0; i < b.N; i++ { var err error
R, err = l.Generate(1, []byte("test"))
if err != nil {
return
}
}
}
func BenchmarkLuaParse(b *testing.B) {
for i := 0; i < b.N; i++ {
var err error
R, err = l.Parse([]byte("1headertest"))
if err != nil {
return
}
}
}

View file

@ -98,11 +98,11 @@ type Device struct {
type awgType struct {
isASecOn abool.AtomicBool
aSecMux sync.RWMutex
mutex sync.RWMutex
aSecCfg aSecCfgType
junkCreator junkCreator
luaAdapter *adapter.Lua
codec *adapter.Lua
}
type aSecCfgType struct {
@ -434,9 +434,9 @@ func (device *Device) Close() {
device.resetProtocol()
if device.awg.luaAdapter != nil {
device.awg.luaAdapter.Close()
device.awg.luaAdapter = nil
if device.awg.codec != nil {
device.awg.codec.Close()
device.awg.codec = nil
}
device.log.Verbosef("Device closed")
close(device.closed)
@ -601,7 +601,7 @@ func (device *Device) handlePostConfig(tempAwgType *awgType) (err error) {
}
isASecOn := false
device.awg.aSecMux.Lock()
device.awg.mutex.Lock()
if tempAwgType.aSecCfg.junkPacketCount < 0 {
err = ipcErrorf(
ipc.IpcErrorInvalid,
@ -828,16 +828,20 @@ func (device *Device) handlePostConfig(tempAwgType *awgType) (err error) {
if device.awg.isASecOn.IsSet() {
device.awg.junkCreator, err = NewJunkCreator(device)
}
device.awg.luaAdapter = tempAwgType.luaAdapter
device.awg.aSecMux.Unlock()
device.awg.codec = tempAwgType.codec
device.awg.mutex.Unlock()
return err
}
func (device *Device) codecPacket(msgType uint32, packet []byte) ([]byte, error) {
if device.awg.luaAdapter != nil {
func (device *Device) isCodecActive() bool {
return device.awg.codec != nil
}
func (device *Device) codecPacketIfActive(msgType uint32, packet []byte) ([]byte, error) {
if device.isCodecActive() {
var err error
packet, err = device.awg.luaAdapter.Generate(int64(msgType),packet)
packet, err = device.awg.codec.Generate(int64(msgType),packet)
if err != nil {
device.log.Errorf("%v - Failed to run codec generate: %v", device, err)
return nil, err

View file

@ -103,11 +103,22 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
}
pub1, pub2 := key1.publicKey(), key2.publicKey()
/*
head = "headheadhead"
tail = "tailtailtail"
function d_gen(msg_type, data, counter)
return head .. data .. tail
end
function d_parse(data)
return string.match(data, head.. "(.-)".. tail)
end
*/
cfgs[0] = uapiCfg(
"private_key", hex.EncodeToString(key1[:]),
"listen_port", "0",
"replace_peers", "true",
"lua_codec", "CQkJZnVuY3Rpb24gZF9nZW4obXNnX3R5cGUsIGRhdGEsIGNvdW50ZXIpCgkJCQlsb2NhbCBoZWFkZXIgPSAiaGVhZGVyIgoJCQkJcmV0dXJuIGhlYWRlciAuLiBkYXRhCgkJCWVuZAoKCQkJZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJCQkJbG9jYWwgaGVhZGVyID0gImhlYWRlciIKCQkJCXJldHVybiBzdHJpbmcuc3ViKGRhdGEsICNoZWFkZXIrMSkKCQkJZW5kCg==",
"lua_codec", "aGVhZCA9ICJoZWFkaGVhZGhlYWQiCnRhaWwgPSAidGFpbHRhaWx0YWlsIgpmdW5jdGlvbiBkX2dlbihtc2dfdHlwZSwgZGF0YSwgY291bnRlcikKCXJldHVybiBoZWFkIC4uIGRhdGEgLi4gdGFpbAplbmQKCmZ1bmN0aW9uIGRfcGFyc2UoZGF0YSkKICAgICAgICByZXR1cm4gc3RyaW5nLm1hdGNoKGRhdGEsIGhlYWQuLiAiKC4tKSIuLiB0YWlsKQplbmQK",
"jc", "5",
"jmin", "500",
"jmax", "1000",
@ -130,7 +141,7 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
"private_key", hex.EncodeToString(key2[:]),
"listen_port", "0",
"replace_peers", "true",
"lua_codec", "CQkJZnVuY3Rpb24gZF9nZW4obXNnX3R5cGUsIGRhdGEsIGNvdW50ZXIpCgkJCQlsb2NhbCBoZWFkZXIgPSAiaGVhZGVyIgoJCQkJcmV0dXJuIGhlYWRlciAuLiBkYXRhCgkJCWVuZAoKCQkJZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJCQkJbG9jYWwgaGVhZGVyID0gImhlYWRlciIKCQkJCXJldHVybiBzdHJpbmcuc3ViKGRhdGEsICNoZWFkZXIrMSkKCQkJZW5kCg==",
"lua_codec", "aGVhZCA9ICJoZWFkaGVhZGhlYWQiCnRhaWwgPSAidGFpbHRhaWx0YWlsIgpmdW5jdGlvbiBkX2dlbihtc2dfdHlwZSwgZGF0YSwgY291bnRlcikKCXJldHVybiBoZWFkIC4uIGRhdGEgLi4gdGFpbAplbmQKCmZ1bmN0aW9uIGRfcGFyc2UoZGF0YSkKICAgICAgICByZXR1cm4gc3RyaW5nLm1hdGNoKGRhdGEsIGhlYWQuLiAiKC4tKSIuLiB0YWlsKQplbmQK",
"jc", "5",
"jmin", "500",
"jmax", "1000",

View file

@ -204,12 +204,12 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(handshake.remoteStatic[:])
device.awg.aSecMux.RLock()
device.awg.mutex.RLock()
msg := MessageInitiation{
Type: MessageInitiationType,
Ephemeral: handshake.localEphemeral.publicKey(),
}
device.awg.aSecMux.RUnlock()
device.awg.mutex.RUnlock()
handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:])
@ -263,12 +263,12 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
chainKey [blake2s.Size]byte
)
device.awg.aSecMux.RLock()
device.awg.mutex.RLock()
if msg.Type != MessageInitiationType {
device.awg.aSecMux.RUnlock()
device.awg.mutex.RUnlock()
return nil
}
device.awg.aSecMux.RUnlock()
device.awg.mutex.RUnlock()
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
@ -383,9 +383,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
}
var msg MessageResponse
device.awg.aSecMux.RLock()
device.awg.mutex.RLock()
msg.Type = MessageResponseType
device.awg.aSecMux.RUnlock()
device.awg.mutex.RUnlock()
msg.Sender = handshake.localIndex
msg.Receiver = handshake.remoteIndex
@ -435,12 +435,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
}
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
device.awg.aSecMux.RLock()
device.awg.mutex.RLock()
if msg.Type != MessageResponseType {
device.awg.aSecMux.RUnlock()
device.awg.mutex.RUnlock()
return nil
}
device.awg.aSecMux.RUnlock()
device.awg.mutex.RUnlock()
// lookup handshake by receiver

View file

@ -12,7 +12,6 @@ import (
"net"
"sync"
"time"
"unsafe"
"github.com/amnezia-vpn/amneziawg-go/conn"
"golang.org/x/crypto/chacha20poly1305"
@ -130,7 +129,7 @@ func (device *Device) RoutineReceiveIncoming(
}
deathSpiral = 0
device.awg.aSecMux.RLock()
device.awg.mutex.RLock()
// handle each packet in the batch
for i, size := range sizes[:count] {
if size < MinMessageSize {
@ -138,15 +137,10 @@ func (device *Device) RoutineReceiveIncoming(
}
packet := bufsArrs[i][:size]
if device.awg.luaAdapter != nil {
realPacket, err := device.awg.luaAdapter.Parse(packet)
packetPtr := (*byte)(unsafe.Pointer(bufsArrs[i])) // Get pointer to the array
// Copy data from realPacket to the memory pointed to by bufsArrs[i]
if device.isCodecActive() {
realPacket, err := device.awg.codec.Parse(packet)
copy(packet, realPacket)
size = len(realPacket)
for j := 0; j < size; j++ {
*(*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(packetPtr)) + uintptr(j))) = realPacket[j]
}
packet = bufsArrs[i][:size]
if err != nil {
device.log.Verbosef(
@ -262,7 +256,7 @@ func (device *Device) RoutineReceiveIncoming(
default:
}
}
device.awg.aSecMux.RUnlock()
device.awg.mutex.RUnlock()
for peer, elemsContainer := range elemsByPeer {
if peer.isRunning.Load() {
peer.queue.inbound.c <- elemsContainer
@ -322,7 +316,7 @@ func (device *Device) RoutineHandshake(id int) {
for elem := range device.queue.handshake.c {
device.awg.aSecMux.RLock()
device.awg.mutex.RLock()
// handle cookie fields and ratelimiting
@ -474,7 +468,7 @@ func (device *Device) RoutineHandshake(id int) {
peer.SendKeepalive()
}
skip:
device.awg.aSecMux.RUnlock()
device.awg.mutex.RUnlock()
device.PutMessageBuffer(elem.buffer)
}
}

View file

@ -127,13 +127,14 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
)
return err
}
var sendBuffer [][]byte
// so only packet processed for cookie generation
var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() {
peer.device.awg.aSecMux.RLock()
peer.device.awg.mutex.RLock()
junks, err := peer.device.awg.junkCreator.createJunkPackets(peer)
peer.device.awg.aSecMux.RUnlock()
peer.device.awg.mutex.RUnlock()
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
@ -153,19 +154,19 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
}
}
peer.device.awg.aSecMux.RLock()
peer.device.awg.mutex.RLock()
if peer.device.awg.aSecCfg.initPacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.awg.aSecCfg.initPacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.initPacketJunkSize)
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
peer.device.awg.aSecMux.RUnlock()
peer.device.awg.mutex.RUnlock()
return err
}
junkedHeader = writer.Bytes()
}
peer.device.awg.aSecMux.RUnlock()
peer.device.awg.mutex.RUnlock()
}
var buf [MessageInitiationSize]byte
@ -175,7 +176,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
peer.cookieGenerator.AddMacs(packet)
junkedHeader = append(junkedHeader, packet...)
if junkedHeader, err = peer.device.codecPacket(DefaultMessageInitiationType, junkedHeader); err != nil {
if junkedHeader, err = peer.device.codecPacketIfActive(DefaultMessageInitiationType, junkedHeader); err != nil {
return err
}
@ -215,19 +216,19 @@ func (peer *Peer) SendHandshakeResponse() error {
}
var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() {
peer.device.awg.aSecMux.RLock()
peer.device.awg.mutex.RLock()
if peer.device.awg.aSecCfg.responsePacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.awg.aSecCfg.responsePacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.responsePacketJunkSize)
if err != nil {
peer.device.awg.aSecMux.RUnlock()
peer.device.awg.mutex.RUnlock()
peer.device.log.Errorf("%v - %v", peer, err)
return err
}
junkedHeader = writer.Bytes()
}
peer.device.awg.aSecMux.RUnlock()
peer.device.awg.mutex.RUnlock()
}
var buf [MessageResponseSize]byte
writer := bytes.NewBuffer(buf[:0])
@ -237,7 +238,7 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.cookieGenerator.AddMacs(packet)
junkedHeader = append(junkedHeader, packet...)
if junkedHeader, err = peer.device.codecPacket(DefaultMessageResponseType, junkedHeader); err != nil {
if junkedHeader, err = peer.device.codecPacketIfActive(DefaultMessageResponseType, junkedHeader); err != nil {
return err
}
@ -286,7 +287,7 @@ func (device *Device) SendHandshakeCookie(
writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, reply)
packet := writer.Bytes()
if packet, err = device.codecPacket(DefaultMessageCookieReplyType, packet); err != nil {
if packet, err = device.codecPacketIfActive(DefaultMessageCookieReplyType, packet); err != nil {
return err
}
@ -592,7 +593,7 @@ func (device *Device) RoutineEncryption(id int) {
nil,
)
var err error
if elem.packet, err = device.codecPacket(DefaultMessageTransportType, elem.packet); err != nil {
if elem.packet, err = device.codecPacketIfActive(DefaultMessageTransportType, elem.packet); err != nil {
continue
}
}

View file

@ -98,8 +98,8 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
sendf("fwmark=%d", device.net.fwmark)
}
if device.awg.luaAdapter != nil {
sendf("lua_codec=%s", device.awg.luaAdapter.Base64LuaCode())
if device.awg.codec != nil {
sendf("lua_codec=%s", device.awg.codec.Base64LuaCode())
}
if device.isAdvancedSecurityOn() {
if device.awg.aSecCfg.junkPacketCount != 0 {
@ -290,7 +290,7 @@ func (device *Device) handleDeviceLine(key, value string, tempAwgTpe *awgType) e
case "lua_codec":
device.log.Verbosef("UAPI: Updating lua_codec")
var err error
tempAwgTpe.luaAdapter, err = adapter.NewLua(adapter.LuaParams{
tempAwgTpe.codec, err = adapter.NewLua(adapter.LuaParams{
Base64LuaCode: value,
})
if err != nil {