mirror of
https://github.com/amnezia-vpn/amneziawg-go.git
synced 2025-04-16 22:16:55 +02:00
rename vars and complete testing
This commit is contained in:
parent
068e146538
commit
0dcd528694
8 changed files with 121 additions and 68 deletions
|
@ -8,9 +8,9 @@ import (
|
|||
"github.com/aarzilli/golua/lua"
|
||||
)
|
||||
|
||||
// TODO: aSec sync is enough?
|
||||
type Lua struct {
|
||||
state *lua.State
|
||||
generateState *lua.State
|
||||
parseState *lua.State
|
||||
packetCounter atomic.Int64
|
||||
base64LuaCode string
|
||||
}
|
||||
|
@ -24,51 +24,72 @@ func NewLua(params LuaParams) (*Lua, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Println(string(luaCode))
|
||||
|
||||
strLuaCode := string(luaCode)
|
||||
|
||||
generateState, err := initState(strLuaCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parseState, err := initState(strLuaCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Lua{
|
||||
generateState: generateState,
|
||||
parseState: parseState,
|
||||
base64LuaCode: params.Base64LuaCode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func initState(luaCode string) (*lua.State, error) {
|
||||
state := lua.NewState()
|
||||
state.OpenLibs()
|
||||
|
||||
if err := state.DoString(string(luaCode)); err != nil {
|
||||
return nil, fmt.Errorf("Error loading Lua code: %v\n", err)
|
||||
}
|
||||
return &Lua{state: state, base64LuaCode: params.Base64LuaCode}, nil
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func (l *Lua) Close() {
|
||||
l.state.Close()
|
||||
l.generateState.Close()
|
||||
l.parseState.Close()
|
||||
}
|
||||
|
||||
// Only thread safe if used by wg packet creation which happens independably
|
||||
func (l *Lua) Generate(
|
||||
msgType int64,
|
||||
data []byte,
|
||||
) ([]byte, error) {
|
||||
l.state.GetGlobal("d_gen")
|
||||
l.generateState.GetGlobal("d_gen")
|
||||
|
||||
l.state.PushInteger(msgType)
|
||||
l.state.PushBytes(data)
|
||||
l.state.PushInteger(l.packetCounter.Add(1))
|
||||
l.generateState.PushInteger(msgType)
|
||||
l.generateState.PushBytes(data)
|
||||
l.generateState.PushInteger(l.packetCounter.Add(1))
|
||||
|
||||
if err := l.state.Call(3, 1); err != nil {
|
||||
if err := l.generateState.Call(3, 1); err != nil {
|
||||
return nil, fmt.Errorf("Error calling Lua function: %v\n", err)
|
||||
}
|
||||
|
||||
result := l.state.ToBytes(-1)
|
||||
l.state.Pop(1)
|
||||
result := l.generateState.ToBytes(-1)
|
||||
l.generateState.Pop(1)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Only thread safe if used by wg packet receive which happens independably
|
||||
func (l *Lua) Parse(data []byte) ([]byte, error) {
|
||||
l.state.GetGlobal("d_parse")
|
||||
l.parseState.GetGlobal("d_parse")
|
||||
|
||||
l.state.PushBytes(data)
|
||||
if err := l.state.Call(1, 1); err != nil {
|
||||
l.parseState.PushBytes(data)
|
||||
if err := l.parseState.Call(1, 1); err != nil {
|
||||
return nil, fmt.Errorf("Error calling Lua function: %v\n", err)
|
||||
}
|
||||
|
||||
result := l.state.ToBytes(-1)
|
||||
l.state.Pop(1)
|
||||
result := l.parseState.ToBytes(-1)
|
||||
l.parseState.Pop(1)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
|
|
@ -58,3 +58,25 @@ func TestLua_Parse(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
var R []byte
|
||||
var l = newLua()
|
||||
func BenchmarkLuaGenerate(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ { var err error
|
||||
R, err = l.Generate(1, []byte("test"))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLuaParse(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
var err error
|
||||
R, err = l.Parse([]byte("1headertest"))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -98,11 +98,11 @@ type Device struct {
|
|||
|
||||
type awgType struct {
|
||||
isASecOn abool.AtomicBool
|
||||
aSecMux sync.RWMutex
|
||||
mutex sync.RWMutex
|
||||
aSecCfg aSecCfgType
|
||||
junkCreator junkCreator
|
||||
|
||||
luaAdapter *adapter.Lua
|
||||
codec *adapter.Lua
|
||||
}
|
||||
|
||||
type aSecCfgType struct {
|
||||
|
@ -434,9 +434,9 @@ func (device *Device) Close() {
|
|||
|
||||
device.resetProtocol()
|
||||
|
||||
if device.awg.luaAdapter != nil {
|
||||
device.awg.luaAdapter.Close()
|
||||
device.awg.luaAdapter = nil
|
||||
if device.awg.codec != nil {
|
||||
device.awg.codec.Close()
|
||||
device.awg.codec = nil
|
||||
}
|
||||
device.log.Verbosef("Device closed")
|
||||
close(device.closed)
|
||||
|
@ -601,7 +601,7 @@ func (device *Device) handlePostConfig(tempAwgType *awgType) (err error) {
|
|||
}
|
||||
|
||||
isASecOn := false
|
||||
device.awg.aSecMux.Lock()
|
||||
device.awg.mutex.Lock()
|
||||
if tempAwgType.aSecCfg.junkPacketCount < 0 {
|
||||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
|
@ -828,16 +828,20 @@ func (device *Device) handlePostConfig(tempAwgType *awgType) (err error) {
|
|||
if device.awg.isASecOn.IsSet() {
|
||||
device.awg.junkCreator, err = NewJunkCreator(device)
|
||||
}
|
||||
device.awg.luaAdapter = tempAwgType.luaAdapter
|
||||
device.awg.aSecMux.Unlock()
|
||||
device.awg.codec = tempAwgType.codec
|
||||
device.awg.mutex.Unlock()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (device *Device) codecPacket(msgType uint32, packet []byte) ([]byte, error) {
|
||||
if device.awg.luaAdapter != nil {
|
||||
func (device *Device) isCodecActive() bool {
|
||||
return device.awg.codec != nil
|
||||
}
|
||||
|
||||
func (device *Device) codecPacketIfActive(msgType uint32, packet []byte) ([]byte, error) {
|
||||
if device.isCodecActive() {
|
||||
var err error
|
||||
packet, err = device.awg.luaAdapter.Generate(int64(msgType),packet)
|
||||
packet, err = device.awg.codec.Generate(int64(msgType),packet)
|
||||
if err != nil {
|
||||
device.log.Errorf("%v - Failed to run codec generate: %v", device, err)
|
||||
return nil, err
|
||||
|
|
|
@ -103,11 +103,22 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
|
|||
}
|
||||
pub1, pub2 := key1.publicKey(), key2.publicKey()
|
||||
|
||||
/*
|
||||
head = "headheadhead"
|
||||
tail = "tailtailtail"
|
||||
function d_gen(msg_type, data, counter)
|
||||
return head .. data .. tail
|
||||
end
|
||||
|
||||
function d_parse(data)
|
||||
return string.match(data, head.. "(.-)".. tail)
|
||||
end
|
||||
*/
|
||||
cfgs[0] = uapiCfg(
|
||||
"private_key", hex.EncodeToString(key1[:]),
|
||||
"listen_port", "0",
|
||||
"replace_peers", "true",
|
||||
"lua_codec", "CQkJZnVuY3Rpb24gZF9nZW4obXNnX3R5cGUsIGRhdGEsIGNvdW50ZXIpCgkJCQlsb2NhbCBoZWFkZXIgPSAiaGVhZGVyIgoJCQkJcmV0dXJuIGhlYWRlciAuLiBkYXRhCgkJCWVuZAoKCQkJZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJCQkJbG9jYWwgaGVhZGVyID0gImhlYWRlciIKCQkJCXJldHVybiBzdHJpbmcuc3ViKGRhdGEsICNoZWFkZXIrMSkKCQkJZW5kCg==",
|
||||
"lua_codec", "aGVhZCA9ICJoZWFkaGVhZGhlYWQiCnRhaWwgPSAidGFpbHRhaWx0YWlsIgpmdW5jdGlvbiBkX2dlbihtc2dfdHlwZSwgZGF0YSwgY291bnRlcikKCXJldHVybiBoZWFkIC4uIGRhdGEgLi4gdGFpbAplbmQKCmZ1bmN0aW9uIGRfcGFyc2UoZGF0YSkKICAgICAgICByZXR1cm4gc3RyaW5nLm1hdGNoKGRhdGEsIGhlYWQuLiAiKC4tKSIuLiB0YWlsKQplbmQK",
|
||||
"jc", "5",
|
||||
"jmin", "500",
|
||||
"jmax", "1000",
|
||||
|
@ -130,7 +141,7 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
|
|||
"private_key", hex.EncodeToString(key2[:]),
|
||||
"listen_port", "0",
|
||||
"replace_peers", "true",
|
||||
"lua_codec", "CQkJZnVuY3Rpb24gZF9nZW4obXNnX3R5cGUsIGRhdGEsIGNvdW50ZXIpCgkJCQlsb2NhbCBoZWFkZXIgPSAiaGVhZGVyIgoJCQkJcmV0dXJuIGhlYWRlciAuLiBkYXRhCgkJCWVuZAoKCQkJZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJCQkJbG9jYWwgaGVhZGVyID0gImhlYWRlciIKCQkJCXJldHVybiBzdHJpbmcuc3ViKGRhdGEsICNoZWFkZXIrMSkKCQkJZW5kCg==",
|
||||
"lua_codec", "aGVhZCA9ICJoZWFkaGVhZGhlYWQiCnRhaWwgPSAidGFpbHRhaWx0YWlsIgpmdW5jdGlvbiBkX2dlbihtc2dfdHlwZSwgZGF0YSwgY291bnRlcikKCXJldHVybiBoZWFkIC4uIGRhdGEgLi4gdGFpbAplbmQKCmZ1bmN0aW9uIGRfcGFyc2UoZGF0YSkKICAgICAgICByZXR1cm4gc3RyaW5nLm1hdGNoKGRhdGEsIGhlYWQuLiAiKC4tKSIuLiB0YWlsKQplbmQK",
|
||||
"jc", "5",
|
||||
"jmin", "500",
|
||||
"jmax", "1000",
|
||||
|
|
|
@ -204,12 +204,12 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
|
|||
|
||||
handshake.mixHash(handshake.remoteStatic[:])
|
||||
|
||||
device.awg.aSecMux.RLock()
|
||||
device.awg.mutex.RLock()
|
||||
msg := MessageInitiation{
|
||||
Type: MessageInitiationType,
|
||||
Ephemeral: handshake.localEphemeral.publicKey(),
|
||||
}
|
||||
device.awg.aSecMux.RUnlock()
|
||||
device.awg.mutex.RUnlock()
|
||||
|
||||
handshake.mixKey(msg.Ephemeral[:])
|
||||
handshake.mixHash(msg.Ephemeral[:])
|
||||
|
@ -263,12 +263,12 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
|||
chainKey [blake2s.Size]byte
|
||||
)
|
||||
|
||||
device.awg.aSecMux.RLock()
|
||||
device.awg.mutex.RLock()
|
||||
if msg.Type != MessageInitiationType {
|
||||
device.awg.aSecMux.RUnlock()
|
||||
device.awg.mutex.RUnlock()
|
||||
return nil
|
||||
}
|
||||
device.awg.aSecMux.RUnlock()
|
||||
device.awg.mutex.RUnlock()
|
||||
|
||||
device.staticIdentity.RLock()
|
||||
defer device.staticIdentity.RUnlock()
|
||||
|
@ -383,9 +383,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
|||
}
|
||||
|
||||
var msg MessageResponse
|
||||
device.awg.aSecMux.RLock()
|
||||
device.awg.mutex.RLock()
|
||||
msg.Type = MessageResponseType
|
||||
device.awg.aSecMux.RUnlock()
|
||||
device.awg.mutex.RUnlock()
|
||||
msg.Sender = handshake.localIndex
|
||||
msg.Receiver = handshake.remoteIndex
|
||||
|
||||
|
@ -435,12 +435,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
|||
}
|
||||
|
||||
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
||||
device.awg.aSecMux.RLock()
|
||||
device.awg.mutex.RLock()
|
||||
if msg.Type != MessageResponseType {
|
||||
device.awg.aSecMux.RUnlock()
|
||||
device.awg.mutex.RUnlock()
|
||||
return nil
|
||||
}
|
||||
device.awg.aSecMux.RUnlock()
|
||||
device.awg.mutex.RUnlock()
|
||||
|
||||
// lookup handshake by receiver
|
||||
|
||||
|
|
|
@ -12,7 +12,6 @@ import (
|
|||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
|
@ -130,7 +129,7 @@ func (device *Device) RoutineReceiveIncoming(
|
|||
}
|
||||
deathSpiral = 0
|
||||
|
||||
device.awg.aSecMux.RLock()
|
||||
device.awg.mutex.RLock()
|
||||
// handle each packet in the batch
|
||||
for i, size := range sizes[:count] {
|
||||
if size < MinMessageSize {
|
||||
|
@ -138,15 +137,10 @@ func (device *Device) RoutineReceiveIncoming(
|
|||
}
|
||||
|
||||
packet := bufsArrs[i][:size]
|
||||
if device.awg.luaAdapter != nil {
|
||||
realPacket, err := device.awg.luaAdapter.Parse(packet)
|
||||
|
||||
packetPtr := (*byte)(unsafe.Pointer(bufsArrs[i])) // Get pointer to the array
|
||||
// Copy data from realPacket to the memory pointed to by bufsArrs[i]
|
||||
if device.isCodecActive() {
|
||||
realPacket, err := device.awg.codec.Parse(packet)
|
||||
copy(packet, realPacket)
|
||||
size = len(realPacket)
|
||||
for j := 0; j < size; j++ {
|
||||
*(*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(packetPtr)) + uintptr(j))) = realPacket[j]
|
||||
}
|
||||
packet = bufsArrs[i][:size]
|
||||
if err != nil {
|
||||
device.log.Verbosef(
|
||||
|
@ -262,7 +256,7 @@ func (device *Device) RoutineReceiveIncoming(
|
|||
default:
|
||||
}
|
||||
}
|
||||
device.awg.aSecMux.RUnlock()
|
||||
device.awg.mutex.RUnlock()
|
||||
for peer, elemsContainer := range elemsByPeer {
|
||||
if peer.isRunning.Load() {
|
||||
peer.queue.inbound.c <- elemsContainer
|
||||
|
@ -322,7 +316,7 @@ func (device *Device) RoutineHandshake(id int) {
|
|||
|
||||
for elem := range device.queue.handshake.c {
|
||||
|
||||
device.awg.aSecMux.RLock()
|
||||
device.awg.mutex.RLock()
|
||||
|
||||
// handle cookie fields and ratelimiting
|
||||
|
||||
|
@ -474,7 +468,7 @@ func (device *Device) RoutineHandshake(id int) {
|
|||
peer.SendKeepalive()
|
||||
}
|
||||
skip:
|
||||
device.awg.aSecMux.RUnlock()
|
||||
device.awg.mutex.RUnlock()
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -127,13 +127,14 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
|||
)
|
||||
return err
|
||||
}
|
||||
|
||||
var sendBuffer [][]byte
|
||||
// so only packet processed for cookie generation
|
||||
var junkedHeader []byte
|
||||
if peer.device.isAdvancedSecurityOn() {
|
||||
peer.device.awg.aSecMux.RLock()
|
||||
peer.device.awg.mutex.RLock()
|
||||
junks, err := peer.device.awg.junkCreator.createJunkPackets(peer)
|
||||
peer.device.awg.aSecMux.RUnlock()
|
||||
peer.device.awg.mutex.RUnlock()
|
||||
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("%v - %v", peer, err)
|
||||
|
@ -153,19 +154,19 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
|||
}
|
||||
}
|
||||
|
||||
peer.device.awg.aSecMux.RLock()
|
||||
peer.device.awg.mutex.RLock()
|
||||
if peer.device.awg.aSecCfg.initPacketJunkSize != 0 {
|
||||
buf := make([]byte, 0, peer.device.awg.aSecCfg.initPacketJunkSize)
|
||||
writer := bytes.NewBuffer(buf[:0])
|
||||
err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.initPacketJunkSize)
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("%v - %v", peer, err)
|
||||
peer.device.awg.aSecMux.RUnlock()
|
||||
peer.device.awg.mutex.RUnlock()
|
||||
return err
|
||||
}
|
||||
junkedHeader = writer.Bytes()
|
||||
}
|
||||
peer.device.awg.aSecMux.RUnlock()
|
||||
peer.device.awg.mutex.RUnlock()
|
||||
}
|
||||
|
||||
var buf [MessageInitiationSize]byte
|
||||
|
@ -175,7 +176,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
|||
peer.cookieGenerator.AddMacs(packet)
|
||||
junkedHeader = append(junkedHeader, packet...)
|
||||
|
||||
if junkedHeader, err = peer.device.codecPacket(DefaultMessageInitiationType, junkedHeader); err != nil {
|
||||
if junkedHeader, err = peer.device.codecPacketIfActive(DefaultMessageInitiationType, junkedHeader); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -215,19 +216,19 @@ func (peer *Peer) SendHandshakeResponse() error {
|
|||
}
|
||||
var junkedHeader []byte
|
||||
if peer.device.isAdvancedSecurityOn() {
|
||||
peer.device.awg.aSecMux.RLock()
|
||||
peer.device.awg.mutex.RLock()
|
||||
if peer.device.awg.aSecCfg.responsePacketJunkSize != 0 {
|
||||
buf := make([]byte, 0, peer.device.awg.aSecCfg.responsePacketJunkSize)
|
||||
writer := bytes.NewBuffer(buf[:0])
|
||||
err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.responsePacketJunkSize)
|
||||
if err != nil {
|
||||
peer.device.awg.aSecMux.RUnlock()
|
||||
peer.device.awg.mutex.RUnlock()
|
||||
peer.device.log.Errorf("%v - %v", peer, err)
|
||||
return err
|
||||
}
|
||||
junkedHeader = writer.Bytes()
|
||||
}
|
||||
peer.device.awg.aSecMux.RUnlock()
|
||||
peer.device.awg.mutex.RUnlock()
|
||||
}
|
||||
var buf [MessageResponseSize]byte
|
||||
writer := bytes.NewBuffer(buf[:0])
|
||||
|
@ -237,7 +238,7 @@ func (peer *Peer) SendHandshakeResponse() error {
|
|||
peer.cookieGenerator.AddMacs(packet)
|
||||
junkedHeader = append(junkedHeader, packet...)
|
||||
|
||||
if junkedHeader, err = peer.device.codecPacket(DefaultMessageResponseType, junkedHeader); err != nil {
|
||||
if junkedHeader, err = peer.device.codecPacketIfActive(DefaultMessageResponseType, junkedHeader); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -286,7 +287,7 @@ func (device *Device) SendHandshakeCookie(
|
|||
writer := bytes.NewBuffer(buf[:0])
|
||||
binary.Write(writer, binary.LittleEndian, reply)
|
||||
packet := writer.Bytes()
|
||||
if packet, err = device.codecPacket(DefaultMessageCookieReplyType, packet); err != nil {
|
||||
if packet, err = device.codecPacketIfActive(DefaultMessageCookieReplyType, packet); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -592,7 +593,7 @@ func (device *Device) RoutineEncryption(id int) {
|
|||
nil,
|
||||
)
|
||||
var err error
|
||||
if elem.packet, err = device.codecPacket(DefaultMessageTransportType, elem.packet); err != nil {
|
||||
if elem.packet, err = device.codecPacketIfActive(DefaultMessageTransportType, elem.packet); err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
|
|
@ -98,8 +98,8 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
|
|||
sendf("fwmark=%d", device.net.fwmark)
|
||||
}
|
||||
|
||||
if device.awg.luaAdapter != nil {
|
||||
sendf("lua_codec=%s", device.awg.luaAdapter.Base64LuaCode())
|
||||
if device.awg.codec != nil {
|
||||
sendf("lua_codec=%s", device.awg.codec.Base64LuaCode())
|
||||
}
|
||||
if device.isAdvancedSecurityOn() {
|
||||
if device.awg.aSecCfg.junkPacketCount != 0 {
|
||||
|
@ -290,7 +290,7 @@ func (device *Device) handleDeviceLine(key, value string, tempAwgTpe *awgType) e
|
|||
case "lua_codec":
|
||||
device.log.Verbosef("UAPI: Updating lua_codec")
|
||||
var err error
|
||||
tempAwgTpe.luaAdapter, err = adapter.NewLua(adapter.LuaParams{
|
||||
tempAwgTpe.codec, err = adapter.NewLua(adapter.LuaParams{
|
||||
Base64LuaCode: value,
|
||||
})
|
||||
if err != nil {
|
||||
|
|
Loading…
Add table
Reference in a new issue