From 44217caa0d41884cc9a21b447fbfa43cceb1d2eb Mon Sep 17 00:00:00 2001 From: Mark Puha Date: Sat, 8 Feb 2025 13:37:52 +0100 Subject: [PATCH 01/15] successfully call lua function from Go --- adapter/Makefile | 2 ++ adapter/lua.go | 45 +++++++++++++++++++++++++++++++++++++++++++++ go.mod | 1 + go.sum | 2 ++ 4 files changed, 50 insertions(+) create mode 100644 adapter/Makefile create mode 100644 adapter/lua.go diff --git a/adapter/Makefile b/adapter/Makefile new file mode 100644 index 0000000..41bc7aa --- /dev/null +++ b/adapter/Makefile @@ -0,0 +1,2 @@ +all: + go run -tags lua54 lua.go diff --git a/adapter/lua.go b/adapter/lua.go new file mode 100644 index 0000000..8c56d4a --- /dev/null +++ b/adapter/lua.go @@ -0,0 +1,45 @@ +package main + +import ( + "encoding/base64" + "fmt" + + "github.com/aarzilli/golua/lua" +) + +func main() { + // luaB64 := `bG9jYWwgZnVuY3Rpb24gZF9nZW4oZGF0YSwgY291bnRlcikKCUhlYWRlciA9IHN0cmluZy5jaGFyKDB4MTIsIDB4MzQsIDB4NTYsIDB4NzgpCgktLSBsb2NhbCB0cyA9IG9zLnRpbWUoKQoJcmV0dXJuIEhlYWRlciAuLiBkYXRhCmVuZAoKbG9jYWwgZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJcmV0dXJuIHN0cmluZy5zdWIoZGF0YSwgI0hlYWRlcikKZW5kCg==` + // only d_gen + // luaB64 := `ZnVuY3Rpb24gRF9nZW4oZGF0YSkKCS0tIEhlYWRlciA9IHN0cmluZy5jaGFyKDB4MTIsIDB4MzQsIDB4NTYsIDB4NzgpCglsb2NhbCBIZWFkZXIgPSAiXHgxMlx4MzRceDU2XHg3OCIKCS0tIGxvY2FsIHRzID0gb3MudGltZSgpCglyZXR1cm4gSGVhZGVyIC4uIGRhdGEKZW5kCg==` + luaB64 := `ZnVuY3Rpb24gRF9nZW4oZGF0YSkKCS0tIEhlYWRlciA9IHN0cmluZy5jaGFyKDB4MTIsIDB4MzQsIDB4NTYsIDB4NzgpCglsb2NhbCBIZWFkZXIgPSAiXHgxMlx4MzRceDU2XHg3OCIKCWxvY2FsIHRzID0gb3MudGltZSgpCglyZXR1cm4gSGVhZGVyIC4uIGRhdGEKZW5kCg==` + sDec, _ := base64.StdEncoding.DecodeString(luaB64) + fmt.Println(string(sDec)) + luaCode := sDec + L := lua.NewState() + L.OpenLibs() + defer L.Close() + + // Load and execute the Lua code + if err := L.DoString(string(luaCode)); err != nil { + fmt.Printf("Error loading Lua code: %v\n", err) + return + } + + // Push the function onto the stack + L.GetGlobal("D_gen") + + // Push the argument + L.PushString("data") + + if err := L.Call(1, 1); err != nil { + fmt.Printf("Error calling Lua function: %v\n", err) + return + } + + result := L.ToString(-1) + L.Pop(1) + + // Print the result + // fmt.Printf("Result: %x\n", []byte(result)) + fmt.Printf("Result: %s\n", result) +} diff --git a/go.mod b/go.mod index b03acb0..8511642 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( ) require ( + github.com/aarzilli/golua v0.0.0-20241229084300-cd31ab23902e // indirect github.com/google/btree v1.0.1 // indirect golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect ) diff --git a/go.sum b/go.sum index 7b53725..119ba05 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/aarzilli/golua v0.0.0-20241229084300-cd31ab23902e h1:ibMKBskN7uMCz9TJgfaIVYVdPyckXm0UjFDRSNV7XB0= +github.com/aarzilli/golua v0.0.0-20241229084300-cd31ab23902e/go.mod h1:hMjfaJVSqVnxenMlsxrq3Ni+vrm9Hs64tU4M7dhUoO4= github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c= From 55f4715a5055fd1aa84d640bb0574f9e64c5ba37 Mon Sep 17 00:00:00 2001 From: Mark Puha Date: Sat, 8 Feb 2025 14:16:03 +0100 Subject: [PATCH 02/15] create lua adapter and test --- adapter/lua.go | 79 +++++++++++++++++++++++++++++++-------------- adapter/lua_test.go | 60 ++++++++++++++++++++++++++++++++++ go.mod | 2 +- 3 files changed, 115 insertions(+), 26 deletions(-) create mode 100644 adapter/lua_test.go diff --git a/adapter/lua.go b/adapter/lua.go index 8c56d4a..a6aceb1 100644 --- a/adapter/lua.go +++ b/adapter/lua.go @@ -1,4 +1,4 @@ -package main +package adapter import ( "encoding/base64" @@ -7,39 +7,68 @@ import ( "github.com/aarzilli/golua/lua" ) -func main() { - // luaB64 := `bG9jYWwgZnVuY3Rpb24gZF9nZW4oZGF0YSwgY291bnRlcikKCUhlYWRlciA9IHN0cmluZy5jaGFyKDB4MTIsIDB4MzQsIDB4NTYsIDB4NzgpCgktLSBsb2NhbCB0cyA9IG9zLnRpbWUoKQoJcmV0dXJuIEhlYWRlciAuLiBkYXRhCmVuZAoKbG9jYWwgZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJcmV0dXJuIHN0cmluZy5zdWIoZGF0YSwgI0hlYWRlcikKZW5kCg==` - // only d_gen - // luaB64 := `ZnVuY3Rpb24gRF9nZW4oZGF0YSkKCS0tIEhlYWRlciA9IHN0cmluZy5jaGFyKDB4MTIsIDB4MzQsIDB4NTYsIDB4NzgpCglsb2NhbCBIZWFkZXIgPSAiXHgxMlx4MzRceDU2XHg3OCIKCS0tIGxvY2FsIHRzID0gb3MudGltZSgpCglyZXR1cm4gSGVhZGVyIC4uIGRhdGEKZW5kCg==` - luaB64 := `ZnVuY3Rpb24gRF9nZW4oZGF0YSkKCS0tIEhlYWRlciA9IHN0cmluZy5jaGFyKDB4MTIsIDB4MzQsIDB4NTYsIDB4NzgpCglsb2NhbCBIZWFkZXIgPSAiXHgxMlx4MzRceDU2XHg3OCIKCWxvY2FsIHRzID0gb3MudGltZSgpCglyZXR1cm4gSGVhZGVyIC4uIGRhdGEKZW5kCg==` - sDec, _ := base64.StdEncoding.DecodeString(luaB64) - fmt.Println(string(sDec)) - luaCode := sDec - L := lua.NewState() - L.OpenLibs() - defer L.Close() +// TODO: aSec sync is enough? +type Lua struct { + state *lua.State +} + +type LuaParams struct { + LuaCode64 string +} + +func NewLua(params LuaParams) (*Lua, error) { + luaCode, err := base64.StdEncoding.DecodeString(params.LuaCode64) + if err != nil { + return nil, err + } + fmt.Println(string(luaCode)) + + state := lua.NewState() + state.OpenLibs() // Load and execute the Lua code - if err := L.DoString(string(luaCode)); err != nil { - fmt.Printf("Error loading Lua code: %v\n", err) - return + if err := state.DoString(string(luaCode)); err != nil { + return nil, fmt.Errorf("Error loading Lua code: %v\n", err) } + return &Lua{state: state}, nil +} +func (l *Lua) Close() { + l.state.Close() +} + +func (l *Lua) Generate(data []byte, counter int64) ([]byte, error) { // Push the function onto the stack - L.GetGlobal("D_gen") + l.state.GetGlobal("D_gen") // Push the argument - L.PushString("data") + l.state.PushBytes(data) + l.state.PushInteger(counter) - if err := L.Call(1, 1); err != nil { - fmt.Printf("Error calling Lua function: %v\n", err) - return + if err := l.state.Call(2, 1); err != nil { + return nil, fmt.Errorf("Error calling Lua function: %v\n", err) } - result := L.ToString(-1) - L.Pop(1) + result := l.state.ToBytes(-1) + l.state.Pop(1) - // Print the result - // fmt.Printf("Result: %x\n", []byte(result)) - fmt.Printf("Result: %s\n", result) + fmt.Printf("Result: %s\n", string(result)) + return result, nil +} + +func (l *Lua) Parse(data []byte) ([]byte, error) { + // Push the function onto the stack + l.state.GetGlobal("D_parse") + + // Push the argument + l.state.PushBytes(data) + if err := l.state.Call(1, 1); err != nil { + return nil, fmt.Errorf("Error calling Lua function: %v\n", err) + } + + result := l.state.ToBytes(-1) + l.state.Pop(1) + + fmt.Printf("Result: %s\n", string(result)) + return result, nil } diff --git a/adapter/lua_test.go b/adapter/lua_test.go new file mode 100644 index 0000000..1d35651 --- /dev/null +++ b/adapter/lua_test.go @@ -0,0 +1,60 @@ +package adapter + +import ( + "testing" +) + +func newLua() *Lua { + lua, _ := NewLua(LuaParams{ + /* + function D_gen(data, counter) + local header = "header" + return counter .. header .. data + end + + function D_parse(data) + local header = "10header" + return string.sub(data, #header+1) + end + */ + LuaCode64: "ZnVuY3Rpb24gRF9nZW4oZGF0YSwgY291bnRlcikKCWxvY2FsIGhlYWRlciA9ICJoZWFkZXIiCglyZXR1cm4gY291bnRlciAuLiBoZWFkZXIgLi4gZGF0YQplbmQKCmZ1bmN0aW9uIERfcGFyc2UoZGF0YSkKCWxvY2FsIGhlYWRlciA9ICIxMGhlYWRlciIKCXJldHVybiBzdHJpbmcuc3ViKGRhdGEsICNoZWFkZXIrMSkKZW5kCg==", + }) + return lua +} + +func TestLua_Generate(t *testing.T) { + t.Run("", func(t *testing.T) { + l := newLua() + defer l.Close() + got, err := l.Generate([]byte("test"), 10) + if err != nil { + t.Errorf( + "Lua.Generate() error = %v, wantErr %v", + err, + nil, + ) + return + } + + want := "10headertest" + if string(got) != want { + t.Errorf("Lua.Generate() = %v, want %v", string(got), want) + } + }) +} + +func TestLua_Parse(t *testing.T) { + t.Run("", func(t *testing.T) { + l := newLua() + defer l.Close() + got, err := l.Parse([]byte("10headertest")) + if err != nil { + t.Errorf("Lua.Parse() error = %v, wantErr %v", err, nil) + return + } + want := "test" + if string(got) != want { + t.Errorf("Lua.Parse() = %v, want %v", got, want) + } + }) +} diff --git a/go.mod b/go.mod index 8511642..435752a 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/amnezia-vpn/amneziawg-go go 1.23 require ( + github.com/aarzilli/golua v0.0.0-20241229084300-cd31ab23902e github.com/tevino/abool/v2 v2.1.0 golang.org/x/crypto v0.21.0 golang.org/x/net v0.21.0 @@ -12,7 +13,6 @@ require ( ) require ( - github.com/aarzilli/golua v0.0.0-20241229084300-cd31ab23902e // indirect github.com/google/btree v1.0.1 // indirect golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect ) From 3015f3ea20573a0eada7372b45c255415736896f Mon Sep 17 00:00:00 2001 From: Mark Puha Date: Sat, 8 Feb 2025 14:35:20 +0100 Subject: [PATCH 03/15] change function initials --- adapter/lua.go | 15 ++++----------- adapter/lua_test.go | 6 +++--- 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/adapter/lua.go b/adapter/lua.go index a6aceb1..47ce7b2 100644 --- a/adapter/lua.go +++ b/adapter/lua.go @@ -13,11 +13,11 @@ type Lua struct { } type LuaParams struct { - LuaCode64 string + Base64LuaCode string } func NewLua(params LuaParams) (*Lua, error) { - luaCode, err := base64.StdEncoding.DecodeString(params.LuaCode64) + luaCode, err := base64.StdEncoding.DecodeString(params.Base64LuaCode) if err != nil { return nil, err } @@ -26,7 +26,6 @@ func NewLua(params LuaParams) (*Lua, error) { state := lua.NewState() state.OpenLibs() - // Load and execute the Lua code if err := state.DoString(string(luaCode)); err != nil { return nil, fmt.Errorf("Error loading Lua code: %v\n", err) } @@ -38,10 +37,8 @@ func (l *Lua) Close() { } func (l *Lua) Generate(data []byte, counter int64) ([]byte, error) { - // Push the function onto the stack - l.state.GetGlobal("D_gen") + l.state.GetGlobal("d_gen") - // Push the argument l.state.PushBytes(data) l.state.PushInteger(counter) @@ -52,15 +49,12 @@ func (l *Lua) Generate(data []byte, counter int64) ([]byte, error) { result := l.state.ToBytes(-1) l.state.Pop(1) - fmt.Printf("Result: %s\n", string(result)) return result, nil } func (l *Lua) Parse(data []byte) ([]byte, error) { - // Push the function onto the stack - l.state.GetGlobal("D_parse") + l.state.GetGlobal("d_parse") - // Push the argument l.state.PushBytes(data) if err := l.state.Call(1, 1); err != nil { return nil, fmt.Errorf("Error calling Lua function: %v\n", err) @@ -69,6 +63,5 @@ func (l *Lua) Parse(data []byte) ([]byte, error) { result := l.state.ToBytes(-1) l.state.Pop(1) - fmt.Printf("Result: %s\n", string(result)) return result, nil } diff --git a/adapter/lua_test.go b/adapter/lua_test.go index 1d35651..08d5420 100644 --- a/adapter/lua_test.go +++ b/adapter/lua_test.go @@ -7,17 +7,17 @@ import ( func newLua() *Lua { lua, _ := NewLua(LuaParams{ /* - function D_gen(data, counter) + function d_gen(data, counter) local header = "header" return counter .. header .. data end - function D_parse(data) + function d_parse(data) local header = "10header" return string.sub(data, #header+1) end */ - LuaCode64: "ZnVuY3Rpb24gRF9nZW4oZGF0YSwgY291bnRlcikKCWxvY2FsIGhlYWRlciA9ICJoZWFkZXIiCglyZXR1cm4gY291bnRlciAuLiBoZWFkZXIgLi4gZGF0YQplbmQKCmZ1bmN0aW9uIERfcGFyc2UoZGF0YSkKCWxvY2FsIGhlYWRlciA9ICIxMGhlYWRlciIKCXJldHVybiBzdHJpbmcuc3ViKGRhdGEsICNoZWFkZXIrMSkKZW5kCg==", + Base64LuaCode: "ZnVuY3Rpb24gZF9nZW4oZGF0YSwgY291bnRlcikKCWxvY2FsIGhlYWRlciA9ICJoZWFkZXIiCglyZXR1cm4gY291bnRlciAuLiBoZWFkZXIgLi4gZGF0YQplbmQKCmZ1bmN0aW9uIGRfcGFyc2UoZGF0YSkKCWxvY2FsIGhlYWRlciA9ICIxMGhlYWRlciIKCXJldHVybiBzdHJpbmcuc3ViKGRhdGEsICNoZWFkZXIrMSkKZW5kCg==", }) return lua } From 32470fa04ee4b131ffca0b13d55a6b362e50b782 Mon Sep 17 00:00:00 2001 From: Mark Puha Date: Sat, 8 Feb 2025 19:00:00 +0100 Subject: [PATCH 04/15] add codec generation/parsing --- device/device.go | 14 ++++++++++---- device/receive.go | 8 +++++++- device/send.go | 32 +++++++++++++++++++++++++++++++- 3 files changed, 48 insertions(+), 6 deletions(-) diff --git a/device/device.go b/device/device.go index 1bedce3..8178667 100644 --- a/device/device.go +++ b/device/device.go @@ -11,6 +11,7 @@ import ( "sync/atomic" "time" + "github.com/amnezia-vpn/amneziawg-go/adapter" "github.com/amnezia-vpn/amneziawg-go/conn" "github.com/amnezia-vpn/amneziawg-go/ipc" "github.com/amnezia-vpn/amneziawg-go/ratelimiter" @@ -92,11 +93,13 @@ type Device struct { closed chan struct{} log *Logger - isASecOn abool.AtomicBool - aSecMux sync.RWMutex - aSecCfg aSecCfgType - + isASecOn abool.AtomicBool + aSecMux sync.RWMutex + aSecCfg aSecCfgType junkCreator junkCreator + + luaAdapter *adapter.Lua + packetCounter atomic.Int64 } type aSecCfgType struct { @@ -428,6 +431,9 @@ func (device *Device) Close() { device.resetProtocol() + if device.luaAdapter != nil { + device.luaAdapter.Close() + } device.log.Verbosef("Device closed") close(device.closed) } diff --git a/device/receive.go b/device/receive.go index 66c1a32..e790048 100644 --- a/device/receive.go +++ b/device/receive.go @@ -137,8 +137,14 @@ func (device *Device) RoutineReceiveIncoming( } // check size of packet - packet := bufsArrs[i][:size] + if device.luaAdapter != nil { + packet, err = device.luaAdapter.Parse(packet) + if err != nil { + device.log.Verbosef("Couldn't parse message; reason: %v", err) + continue + } + } var msgType uint32 if device.isAdvancedSecurityOn() { if assumedMsgType, ok := packetSizeToMsgType[size]; ok { diff --git a/device/send.go b/device/send.go index 5c54d4d..c0b9ad4 100644 --- a/device/send.go +++ b/device/send.go @@ -175,6 +175,10 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { peer.cookieGenerator.AddMacs(packet) junkedHeader = append(junkedHeader, packet...) + if junkedHeader, err = peer.device.codecPacket(junkedHeader); err != nil { + return err + } + peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() @@ -233,6 +237,10 @@ func (peer *Peer) SendHandshakeResponse() error { peer.cookieGenerator.AddMacs(packet) junkedHeader = append(junkedHeader, packet...) + if junkedHeader, err = peer.device.codecPacket(junkedHeader); err != nil { + return err + } + err = peer.BeginSymmetricSession() if err != nil { peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err) @@ -277,8 +285,13 @@ func (device *Device) SendHandshakeCookie( var buf [MessageCookieReplySize]byte writer := bytes.NewBuffer(buf[:0]) binary.Write(writer, binary.LittleEndian, reply) + packet := writer.Bytes() + if packet, err = device.codecPacket(packet); err != nil { + return err + } + // TODO: allocation could be avoided - device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint) + device.net.bind.Send([][]byte{packet}, initiatingElem.endpoint) return nil } @@ -534,6 +547,18 @@ func calculatePaddingSize(packetSize, mtu int) int { return paddedSize - lastUnit } +func (device *Device) codecPacket(packet []byte) ([]byte, error) { + if device.luaAdapter != nil { + var err error + packet, err = device.luaAdapter.Generate(packet, device.packetCounter.Add(1)) + if err != nil { + device.log.Errorf("%v - Failed to run codec generate: %v", device, err) + return nil, err + } + } + return packet, nil +} + /* Encrypts the elements in the queue * and marks them for sequential consumption (by releasing the mutex) * @@ -578,6 +603,11 @@ func (device *Device) RoutineEncryption(id int) { elem.packet, nil, ) + // TODO: check + var err error + if elem.packet, err = device.codecPacket(elem.packet); err != nil { + continue + } } elemsContainer.Unlock() } From 553060f804a4429356037143a3dacd71d2e74a19 Mon Sep 17 00:00:00 2001 From: Mark Puha Date: Sun, 9 Feb 2025 09:21:01 +0100 Subject: [PATCH 05/15] Send msg type to code & add default msg types --- adapter/lua.go | 5 +++-- adapter/lua_test.go | 20 ++++++++++---------- device/device.go | 17 +++++++++-------- device/noise-protocol.go | 15 +++++++++++---- device/send.go | 14 +++++++------- 5 files changed, 40 insertions(+), 31 deletions(-) diff --git a/adapter/lua.go b/adapter/lua.go index 47ce7b2..e44ca70 100644 --- a/adapter/lua.go +++ b/adapter/lua.go @@ -36,13 +36,14 @@ func (l *Lua) Close() { l.state.Close() } -func (l *Lua) Generate(data []byte, counter int64) ([]byte, error) { +func (l *Lua) Generate(msgType int64, data []byte, counter int64) ([]byte, error) { l.state.GetGlobal("d_gen") + l.state.PushInteger(msgType) l.state.PushBytes(data) l.state.PushInteger(counter) - if err := l.state.Call(2, 1); err != nil { + if err := l.state.Call(3, 1); err != nil { return nil, fmt.Errorf("Error calling Lua function: %v\n", err) } diff --git a/adapter/lua_test.go b/adapter/lua_test.go index 08d5420..828b555 100644 --- a/adapter/lua_test.go +++ b/adapter/lua_test.go @@ -7,17 +7,17 @@ import ( func newLua() *Lua { lua, _ := NewLua(LuaParams{ /* - function d_gen(data, counter) - local header = "header" - return counter .. header .. data - end + function d_gen(msg_type, data, counter) + local header = "header" + return counter .. header .. data + end - function d_parse(data) - local header = "10header" - return string.sub(data, #header+1) - end + function d_parse(data) + local header = "10header" + return string.sub(data, #header+1) + end */ - Base64LuaCode: "ZnVuY3Rpb24gZF9nZW4oZGF0YSwgY291bnRlcikKCWxvY2FsIGhlYWRlciA9ICJoZWFkZXIiCglyZXR1cm4gY291bnRlciAuLiBoZWFkZXIgLi4gZGF0YQplbmQKCmZ1bmN0aW9uIGRfcGFyc2UoZGF0YSkKCWxvY2FsIGhlYWRlciA9ICIxMGhlYWRlciIKCXJldHVybiBzdHJpbmcuc3ViKGRhdGEsICNoZWFkZXIrMSkKZW5kCg==", + Base64LuaCode: "CmZ1bmN0aW9uIGRfZ2VuKG1zZ190eXBlLCBkYXRhLCBjb3VudGVyKQoJbG9jYWwgaGVhZGVyID0gImhlYWRlciIKCXJldHVybiBjb3VudGVyIC4uIGhlYWRlciAuLiBkYXRhCmVuZAoKZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJbG9jYWwgaGVhZGVyID0gIjEwaGVhZGVyIgoJcmV0dXJuIHN0cmluZy5zdWIoZGF0YSwgI2hlYWRlcisxKQplbmQK", }) return lua } @@ -26,7 +26,7 @@ func TestLua_Generate(t *testing.T) { t.Run("", func(t *testing.T) { l := newLua() defer l.Close() - got, err := l.Generate([]byte("test"), 10) + got, err := l.Generate(1, []byte("test"), 10) if err != nil { t.Errorf( "Lua.Generate() error = %v, wantErr %v", diff --git a/device/device.go b/device/device.go index 8178667..aaddcbc 100644 --- a/device/device.go +++ b/device/device.go @@ -433,6 +433,7 @@ func (device *Device) Close() { if device.luaAdapter != nil { device.luaAdapter.Close() + device.luaAdapter = nil } device.log.Verbosef("Device closed") close(device.closed) @@ -585,10 +586,10 @@ func (device *Device) isAdvancedSecurityOn() bool { func (device *Device) resetProtocol() { // restore default message type values - MessageInitiationType = 1 - MessageResponseType = 2 - MessageCookieReplyType = 3 - MessageTransportType = 4 + MessageInitiationType = DefaultMessageInitiationType + MessageResponseType = DefaultMessageResponseType + MessageCookieReplyType = DefaultMessageCookieReplyType + MessageTransportType = DefaultMessageTransportType } func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { @@ -721,7 +722,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { MessageInitiationType = device.aSecCfg.initPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default init type") - MessageInitiationType = 1 + MessageInitiationType = DefaultMessageInitiationType } if tempASecCfg.responsePacketMagicHeader > 4 { @@ -731,7 +732,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { MessageResponseType = device.aSecCfg.responsePacketMagicHeader } else { device.log.Verbosef("UAPI: Using default response type") - MessageResponseType = 2 + MessageResponseType = DefaultMessageResponseType } if tempASecCfg.underloadPacketMagicHeader > 4 { @@ -741,7 +742,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default underload type") - MessageCookieReplyType = 3 + MessageCookieReplyType = DefaultMessageCookieReplyType } if tempASecCfg.transportPacketMagicHeader > 4 { @@ -751,7 +752,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { MessageTransportType = device.aSecCfg.transportPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default transport type") - MessageTransportType = 4 + MessageTransportType = DefaultMessageTransportType } isSameMap := map[uint32]bool{} diff --git a/device/noise-protocol.go b/device/noise-protocol.go index 1289249..d818238 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -52,11 +52,18 @@ const ( WGLabelCookie = "cookie--" ) +const ( + DefaultMessageInitiationType uint32 = 1 + DefaultMessageResponseType uint32 = 2 + DefaultMessageCookieReplyType uint32 = 3 + DefaultMessageTransportType uint32 = 4 +) + var ( - MessageInitiationType uint32 = 1 - MessageResponseType uint32 = 2 - MessageCookieReplyType uint32 = 3 - MessageTransportType uint32 = 4 + MessageInitiationType uint32 = DefaultMessageInitiationType + MessageResponseType uint32 = DefaultMessageResponseType + MessageCookieReplyType uint32 = DefaultMessageCookieReplyType + MessageTransportType uint32 = DefaultMessageTransportType ) const ( diff --git a/device/send.go b/device/send.go index c0b9ad4..ac05313 100644 --- a/device/send.go +++ b/device/send.go @@ -175,7 +175,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { peer.cookieGenerator.AddMacs(packet) junkedHeader = append(junkedHeader, packet...) - if junkedHeader, err = peer.device.codecPacket(junkedHeader); err != nil { + if junkedHeader, err = peer.device.codecPacket(DefaultMessageInitiationType, junkedHeader); err != nil { return err } @@ -237,7 +237,7 @@ func (peer *Peer) SendHandshakeResponse() error { peer.cookieGenerator.AddMacs(packet) junkedHeader = append(junkedHeader, packet...) - if junkedHeader, err = peer.device.codecPacket(junkedHeader); err != nil { + if junkedHeader, err = peer.device.codecPacket(DefaultMessageResponseType, junkedHeader); err != nil { return err } @@ -286,7 +286,7 @@ func (device *Device) SendHandshakeCookie( writer := bytes.NewBuffer(buf[:0]) binary.Write(writer, binary.LittleEndian, reply) packet := writer.Bytes() - if packet, err = device.codecPacket(packet); err != nil { + if packet, err = device.codecPacket(DefaultMessageCookieReplyType, packet); err != nil { return err } @@ -547,10 +547,10 @@ func calculatePaddingSize(packetSize, mtu int) int { return paddedSize - lastUnit } -func (device *Device) codecPacket(packet []byte) ([]byte, error) { +func (device *Device) codecPacket(msgType uint32, packet []byte) ([]byte, error) { if device.luaAdapter != nil { var err error - packet, err = device.luaAdapter.Generate(packet, device.packetCounter.Add(1)) + packet, err = device.luaAdapter.Generate(int64(msgType),packet, device.packetCounter.Add(1)) if err != nil { device.log.Errorf("%v - Failed to run codec generate: %v", device, err) return nil, err @@ -603,9 +603,9 @@ func (device *Device) RoutineEncryption(id int) { elem.packet, nil, ) - // TODO: check + var err error - if elem.packet, err = device.codecPacket(elem.packet); err != nil { + if elem.packet, err = device.codecPacket(DefaultMessageTransportType, elem.packet); err != nil { continue } } From e08105f85b8b1da6ae558d5c14672c34e2338a47 Mon Sep 17 00:00:00 2001 From: Mark Puha Date: Sun, 9 Feb 2025 09:36:41 +0100 Subject: [PATCH 06/15] migrate counter --- adapter/lua.go | 11 ++++++++--- adapter/lua_test.go | 4 ++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/adapter/lua.go b/adapter/lua.go index e44ca70..1a845c4 100644 --- a/adapter/lua.go +++ b/adapter/lua.go @@ -3,13 +3,15 @@ package adapter import ( "encoding/base64" "fmt" + "sync/atomic" "github.com/aarzilli/golua/lua" ) // TODO: aSec sync is enough? type Lua struct { - state *lua.State + state *lua.State + packetCounter atomic.Int64 } type LuaParams struct { @@ -36,12 +38,15 @@ func (l *Lua) Close() { l.state.Close() } -func (l *Lua) Generate(msgType int64, data []byte, counter int64) ([]byte, error) { +func (l *Lua) Generate( + msgType int64, + data []byte, +) ([]byte, error) { l.state.GetGlobal("d_gen") l.state.PushInteger(msgType) l.state.PushBytes(data) - l.state.PushInteger(counter) + l.state.PushInteger(l.packetCounter.Add(1)) if err := l.state.Call(3, 1); err != nil { return nil, fmt.Errorf("Error calling Lua function: %v\n", err) diff --git a/adapter/lua_test.go b/adapter/lua_test.go index 828b555..c7027d0 100644 --- a/adapter/lua_test.go +++ b/adapter/lua_test.go @@ -26,7 +26,7 @@ func TestLua_Generate(t *testing.T) { t.Run("", func(t *testing.T) { l := newLua() defer l.Close() - got, err := l.Generate(1, []byte("test"), 10) + got, err := l.Generate(1, []byte("test")) if err != nil { t.Errorf( "Lua.Generate() error = %v, wantErr %v", @@ -36,7 +36,7 @@ func TestLua_Generate(t *testing.T) { return } - want := "10headertest" + want := "1headertest" if string(got) != want { t.Errorf("Lua.Generate() = %v, want %v", string(got), want) } From 6b0bbcc75cb88c1cf1370ca40b8907e378b205e8 Mon Sep 17 00:00:00 2001 From: Mark Puha Date: Sun, 9 Feb 2025 09:36:53 +0100 Subject: [PATCH 07/15] migrate awg specific code to struct --- device/device.go | 65 ++++++++++++++++++++----------------- device/junk_creator.go | 10 +++--- device/junk_creator_test.go | 8 ++--- device/noise-protocol.go | 20 ++++++------ device/receive.go | 12 +++---- device/send.go | 34 +++++++++---------- device/uapi.go | 36 ++++++++++---------- 7 files changed, 95 insertions(+), 90 deletions(-) diff --git a/device/device.go b/device/device.go index aaddcbc..f375663 100644 --- a/device/device.go +++ b/device/device.go @@ -92,14 +92,17 @@ type Device struct { ipcMutex sync.RWMutex closed chan struct{} log *Logger + + awg awgType +} +type awgType struct { isASecOn abool.AtomicBool aSecMux sync.RWMutex aSecCfg aSecCfgType junkCreator junkCreator luaAdapter *adapter.Lua - packetCounter atomic.Int64 } type aSecCfgType struct { @@ -431,9 +434,9 @@ func (device *Device) Close() { device.resetProtocol() - if device.luaAdapter != nil { - device.luaAdapter.Close() - device.luaAdapter = nil + if device.awg.luaAdapter != nil { + device.awg.luaAdapter.Close() + device.awg.luaAdapter = nil } device.log.Verbosef("Device closed") close(device.closed) @@ -581,7 +584,7 @@ func (device *Device) BindClose() error { return err } func (device *Device) isAdvancedSecurityOn() bool { - return device.isASecOn.IsSet() + return device.awg.isASecOn.IsSet() } func (device *Device) resetProtocol() { @@ -594,36 +597,36 @@ func (device *Device) resetProtocol() { func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { if !tempASecCfg.isSet { - return err + return nil } isASecOn := false - device.aSecMux.Lock() + device.awg.aSecMux.Lock() if tempASecCfg.junkPacketCount < 0 { err = ipcErrorf( ipc.IpcErrorInvalid, "JunkPacketCount should be non negative", ) } - device.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount + device.awg.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount if tempASecCfg.junkPacketCount != 0 { isASecOn = true } - device.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize + device.awg.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize if tempASecCfg.junkPacketMinSize != 0 { isASecOn = true } - if device.aSecCfg.junkPacketCount > 0 && + if device.awg.aSecCfg.junkPacketCount > 0 && tempASecCfg.junkPacketMaxSize == tempASecCfg.junkPacketMinSize { tempASecCfg.junkPacketMaxSize++ // to make rand gen work } if tempASecCfg.junkPacketMaxSize >= MaxSegmentSize { - device.aSecCfg.junkPacketMinSize = 0 - device.aSecCfg.junkPacketMaxSize = 1 + device.awg.aSecCfg.junkPacketMinSize = 0 + device.awg.aSecCfg.junkPacketMaxSize = 1 if err != nil { err = ipcErrorf( ipc.IpcErrorInvalid, @@ -658,7 +661,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { ) } } else { - device.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize + device.awg.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize } if tempASecCfg.junkPacketMaxSize != 0 { @@ -683,7 +686,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { ) } } else { - device.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize + device.awg.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize } if tempASecCfg.initPacketJunkSize != 0 { @@ -708,7 +711,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { ) } } else { - device.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize + device.awg.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize } if tempASecCfg.responsePacketJunkSize != 0 { @@ -718,8 +721,8 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { if tempASecCfg.initPacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating init_packet_magic_header") - device.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader - MessageInitiationType = device.aSecCfg.initPacketMagicHeader + device.awg.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader + MessageInitiationType = device.awg.aSecCfg.initPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default init type") MessageInitiationType = DefaultMessageInitiationType @@ -728,8 +731,8 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { if tempASecCfg.responsePacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating response_packet_magic_header") - device.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader - MessageResponseType = device.aSecCfg.responsePacketMagicHeader + device.awg.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader + MessageResponseType = device.awg.aSecCfg.responsePacketMagicHeader } else { device.log.Verbosef("UAPI: Using default response type") MessageResponseType = DefaultMessageResponseType @@ -738,8 +741,8 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { if tempASecCfg.underloadPacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating underload_packet_magic_header") - device.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader - MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader + device.awg.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader + MessageCookieReplyType = device.awg.aSecCfg.underloadPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default underload type") MessageCookieReplyType = DefaultMessageCookieReplyType @@ -748,8 +751,8 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { if tempASecCfg.transportPacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating transport_packet_magic_header") - device.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader - MessageTransportType = device.aSecCfg.transportPacketMagicHeader + device.awg.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader + MessageTransportType = device.awg.aSecCfg.transportPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default transport type") MessageTransportType = DefaultMessageTransportType @@ -785,8 +788,8 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { } } - newInitSize := MessageInitiationSize + device.aSecCfg.initPacketJunkSize - newResponseSize := MessageResponseSize + device.aSecCfg.responsePacketJunkSize + newInitSize := MessageInitiationSize + device.awg.aSecCfg.initPacketJunkSize + newResponseSize := MessageResponseSize + device.awg.aSecCfg.responsePacketJunkSize if newInitSize == newResponseSize { if err != nil { @@ -814,16 +817,18 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { } msgTypeToJunkSize = map[uint32]int{ - MessageInitiationType: device.aSecCfg.initPacketJunkSize, - MessageResponseType: device.aSecCfg.responsePacketJunkSize, + MessageInitiationType: device.awg.aSecCfg.initPacketJunkSize, + MessageResponseType: device.awg.aSecCfg.responsePacketJunkSize, MessageCookieReplyType: 0, MessageTransportType: 0, } } - device.isASecOn.SetTo(isASecOn) - device.junkCreator, err = NewJunkCreator(device) - device.aSecMux.Unlock() + device.awg.isASecOn.SetTo(isASecOn) + if device.awg.isASecOn.IsSet() { + device.awg.junkCreator, err = NewJunkCreator(device) + } + device.awg.aSecMux.Unlock() return err } diff --git a/device/junk_creator.go b/device/junk_creator.go index 85a5bbc..2bfb197 100644 --- a/device/junk_creator.go +++ b/device/junk_creator.go @@ -23,12 +23,12 @@ func NewJunkCreator(d *Device) (junkCreator, error) { // Should be called with aSecMux RLocked func (jc *junkCreator) createJunkPackets(peer *Peer) ([][]byte, error) { - if jc.device.aSecCfg.junkPacketCount == 0 { + if jc.device.awg.aSecCfg.junkPacketCount == 0 { return nil, nil } - junks := make([][]byte, 0, jc.device.aSecCfg.junkPacketCount) - for i := 0; i < jc.device.aSecCfg.junkPacketCount; i++ { + junks := make([][]byte, 0, jc.device.awg.aSecCfg.junkPacketCount) + for i := 0; i < jc.device.awg.aSecCfg.junkPacketCount; i++ { packetSize := jc.randomPacketSize() junk, err := jc.randomJunkWithSize(packetSize) if err != nil { @@ -48,9 +48,9 @@ func (jc *junkCreator) createJunkPackets(peer *Peer) ([][]byte, error) { func (jc *junkCreator) randomPacketSize() int { return int( jc.cha8Rand.Uint64()%uint64( - jc.device.aSecCfg.junkPacketMaxSize-jc.device.aSecCfg.junkPacketMinSize, + jc.device.awg.aSecCfg.junkPacketMaxSize-jc.device.awg.aSecCfg.junkPacketMinSize, ), - ) + jc.device.aSecCfg.junkPacketMinSize + ) + jc.device.awg.aSecCfg.junkPacketMinSize } // Should be called with aSecMux RLocked diff --git a/device/junk_creator_test.go b/device/junk_creator_test.go index 6f63360..96fa6d3 100644 --- a/device/junk_creator_test.go +++ b/device/junk_creator_test.go @@ -91,13 +91,13 @@ func Test_junkCreator_randomPacketSize(t *testing.T) { } for range [30]struct{}{} { t.Run("", func(t *testing.T) { - if got := jc.randomPacketSize(); jc.device.aSecCfg.junkPacketMinSize > got || - got > jc.device.aSecCfg.junkPacketMaxSize { + if got := jc.randomPacketSize(); jc.device.awg.aSecCfg.junkPacketMinSize > got || + got > jc.device.awg.aSecCfg.junkPacketMaxSize { t.Errorf( "junkCreator.randomPacketSize() = %v, not between range [%v,%v]", got, - jc.device.aSecCfg.junkPacketMinSize, - jc.device.aSecCfg.junkPacketMaxSize, + jc.device.awg.aSecCfg.junkPacketMinSize, + jc.device.awg.aSecCfg.junkPacketMaxSize, ) } }) diff --git a/device/noise-protocol.go b/device/noise-protocol.go index d818238..2dc0d93 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -204,12 +204,12 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e handshake.mixHash(handshake.remoteStatic[:]) - device.aSecMux.RLock() + device.awg.aSecMux.RLock() msg := MessageInitiation{ Type: MessageInitiationType, Ephemeral: handshake.localEphemeral.publicKey(), } - device.aSecMux.RUnlock() + device.awg.aSecMux.RUnlock() handshake.mixKey(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:]) @@ -263,12 +263,12 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { chainKey [blake2s.Size]byte ) - device.aSecMux.RLock() + device.awg.aSecMux.RLock() if msg.Type != MessageInitiationType { - device.aSecMux.RUnlock() + device.awg.aSecMux.RUnlock() return nil } - device.aSecMux.RUnlock() + device.awg.aSecMux.RUnlock() device.staticIdentity.RLock() defer device.staticIdentity.RUnlock() @@ -383,9 +383,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error } var msg MessageResponse - device.aSecMux.RLock() + device.awg.aSecMux.RLock() msg.Type = MessageResponseType - device.aSecMux.RUnlock() + device.awg.aSecMux.RUnlock() msg.Sender = handshake.localIndex msg.Receiver = handshake.remoteIndex @@ -435,12 +435,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error } func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { - device.aSecMux.RLock() + device.awg.aSecMux.RLock() if msg.Type != MessageResponseType { - device.aSecMux.RUnlock() + device.awg.aSecMux.RUnlock() return nil } - device.aSecMux.RUnlock() + device.awg.aSecMux.RUnlock() // lookup handshake by receiver diff --git a/device/receive.go b/device/receive.go index e790048..70801ef 100644 --- a/device/receive.go +++ b/device/receive.go @@ -129,7 +129,7 @@ func (device *Device) RoutineReceiveIncoming( } deathSpiral = 0 - device.aSecMux.RLock() + device.awg.aSecMux.RLock() // handle each packet in the batch for i, size := range sizes[:count] { if size < MinMessageSize { @@ -138,8 +138,8 @@ func (device *Device) RoutineReceiveIncoming( // check size of packet packet := bufsArrs[i][:size] - if device.luaAdapter != nil { - packet, err = device.luaAdapter.Parse(packet) + if device.awg.luaAdapter != nil { + packet, err = device.awg.luaAdapter.Parse(packet) if err != nil { device.log.Verbosef("Couldn't parse message; reason: %v", err) continue @@ -251,7 +251,7 @@ func (device *Device) RoutineReceiveIncoming( default: } } - device.aSecMux.RUnlock() + device.awg.aSecMux.RUnlock() for peer, elemsContainer := range elemsByPeer { if peer.isRunning.Load() { peer.queue.inbound.c <- elemsContainer @@ -310,7 +310,7 @@ func (device *Device) RoutineHandshake(id int) { for elem := range device.queue.handshake.c { - device.aSecMux.RLock() + device.awg.aSecMux.RLock() // handle cookie fields and ratelimiting @@ -462,7 +462,7 @@ func (device *Device) RoutineHandshake(id int) { peer.SendKeepalive() } skip: - device.aSecMux.RUnlock() + device.awg.aSecMux.RUnlock() device.PutMessageBuffer(elem.buffer) } } diff --git a/device/send.go b/device/send.go index ac05313..297f19e 100644 --- a/device/send.go +++ b/device/send.go @@ -131,9 +131,9 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { // so only packet processed for cookie generation var junkedHeader []byte if peer.device.isAdvancedSecurityOn() { - peer.device.aSecMux.RLock() - junks, err := peer.device.junkCreator.createJunkPackets(peer) - peer.device.aSecMux.RUnlock() + peer.device.awg.aSecMux.RLock() + junks, err := peer.device.awg.junkCreator.createJunkPackets(peer) + peer.device.awg.aSecMux.RUnlock() if err != nil { peer.device.log.Errorf("%v - %v", peer, err) @@ -153,19 +153,19 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { } } - peer.device.aSecMux.RLock() - if peer.device.aSecCfg.initPacketJunkSize != 0 { - buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize) + peer.device.awg.aSecMux.RLock() + if peer.device.awg.aSecCfg.initPacketJunkSize != 0 { + buf := make([]byte, 0, peer.device.awg.aSecCfg.initPacketJunkSize) writer := bytes.NewBuffer(buf[:0]) - err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize) + err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.initPacketJunkSize) if err != nil { peer.device.log.Errorf("%v - %v", peer, err) - peer.device.aSecMux.RUnlock() + peer.device.awg.aSecMux.RUnlock() return err } junkedHeader = writer.Bytes() } - peer.device.aSecMux.RUnlock() + peer.device.awg.aSecMux.RUnlock() } var buf [MessageInitiationSize]byte @@ -215,19 +215,19 @@ func (peer *Peer) SendHandshakeResponse() error { } var junkedHeader []byte if peer.device.isAdvancedSecurityOn() { - peer.device.aSecMux.RLock() - if peer.device.aSecCfg.responsePacketJunkSize != 0 { - buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize) + peer.device.awg.aSecMux.RLock() + if peer.device.awg.aSecCfg.responsePacketJunkSize != 0 { + buf := make([]byte, 0, peer.device.awg.aSecCfg.responsePacketJunkSize) writer := bytes.NewBuffer(buf[:0]) - err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize) + err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.responsePacketJunkSize) if err != nil { - peer.device.aSecMux.RUnlock() + peer.device.awg.aSecMux.RUnlock() peer.device.log.Errorf("%v - %v", peer, err) return err } junkedHeader = writer.Bytes() } - peer.device.aSecMux.RUnlock() + peer.device.awg.aSecMux.RUnlock() } var buf [MessageResponseSize]byte writer := bytes.NewBuffer(buf[:0]) @@ -548,9 +548,9 @@ func calculatePaddingSize(packetSize, mtu int) int { } func (device *Device) codecPacket(msgType uint32, packet []byte) ([]byte, error) { - if device.luaAdapter != nil { + if device.awg.luaAdapter != nil { var err error - packet, err = device.luaAdapter.Generate(int64(msgType),packet, device.packetCounter.Add(1)) + packet, err = device.awg.luaAdapter.Generate(int64(msgType),packet) if err != nil { device.log.Errorf("%v - Failed to run codec generate: %v", device, err) return nil, err diff --git a/device/uapi.go b/device/uapi.go index 777bdda..1bce39b 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -98,32 +98,32 @@ func (device *Device) IpcGetOperation(w io.Writer) error { } if device.isAdvancedSecurityOn() { - if device.aSecCfg.junkPacketCount != 0 { - sendf("jc=%d", device.aSecCfg.junkPacketCount) + if device.awg.aSecCfg.junkPacketCount != 0 { + sendf("jc=%d", device.awg.aSecCfg.junkPacketCount) } - if device.aSecCfg.junkPacketMinSize != 0 { - sendf("jmin=%d", device.aSecCfg.junkPacketMinSize) + if device.awg.aSecCfg.junkPacketMinSize != 0 { + sendf("jmin=%d", device.awg.aSecCfg.junkPacketMinSize) } - if device.aSecCfg.junkPacketMaxSize != 0 { - sendf("jmax=%d", device.aSecCfg.junkPacketMaxSize) + if device.awg.aSecCfg.junkPacketMaxSize != 0 { + sendf("jmax=%d", device.awg.aSecCfg.junkPacketMaxSize) } - if device.aSecCfg.initPacketJunkSize != 0 { - sendf("s1=%d", device.aSecCfg.initPacketJunkSize) + if device.awg.aSecCfg.initPacketJunkSize != 0 { + sendf("s1=%d", device.awg.aSecCfg.initPacketJunkSize) } - if device.aSecCfg.responsePacketJunkSize != 0 { - sendf("s2=%d", device.aSecCfg.responsePacketJunkSize) + if device.awg.aSecCfg.responsePacketJunkSize != 0 { + sendf("s2=%d", device.awg.aSecCfg.responsePacketJunkSize) } - if device.aSecCfg.initPacketMagicHeader != 0 { - sendf("h1=%d", device.aSecCfg.initPacketMagicHeader) + if device.awg.aSecCfg.initPacketMagicHeader != 0 { + sendf("h1=%d", device.awg.aSecCfg.initPacketMagicHeader) } - if device.aSecCfg.responsePacketMagicHeader != 0 { - sendf("h2=%d", device.aSecCfg.responsePacketMagicHeader) + if device.awg.aSecCfg.responsePacketMagicHeader != 0 { + sendf("h2=%d", device.awg.aSecCfg.responsePacketMagicHeader) } - if device.aSecCfg.underloadPacketMagicHeader != 0 { - sendf("h3=%d", device.aSecCfg.underloadPacketMagicHeader) + if device.awg.aSecCfg.underloadPacketMagicHeader != 0 { + sendf("h3=%d", device.awg.aSecCfg.underloadPacketMagicHeader) } - if device.aSecCfg.transportPacketMagicHeader != 0 { - sendf("h4=%d", device.aSecCfg.transportPacketMagicHeader) + if device.awg.aSecCfg.transportPacketMagicHeader != 0 { + sendf("h4=%d", device.awg.aSecCfg.transportPacketMagicHeader) } } From f4bc11733deb06a12cec3f7f345d4da281847e37 Mon Sep 17 00:00:00 2001 From: Mark Puha Date: Sun, 9 Feb 2025 09:50:24 +0100 Subject: [PATCH 08/15] add lua codec parsing --- adapter/lua.go | 7 ++++- device/device.go | 75 ++++++++++++++++++++++++------------------------ device/uapi.go | 59 ++++++++++++++++++++++--------------- 3 files changed, 80 insertions(+), 61 deletions(-) diff --git a/adapter/lua.go b/adapter/lua.go index 1a845c4..371878d 100644 --- a/adapter/lua.go +++ b/adapter/lua.go @@ -12,6 +12,7 @@ import ( type Lua struct { state *lua.State packetCounter atomic.Int64 + base64LuaCode string } type LuaParams struct { @@ -31,7 +32,7 @@ func NewLua(params LuaParams) (*Lua, error) { if err := state.DoString(string(luaCode)); err != nil { return nil, fmt.Errorf("Error loading Lua code: %v\n", err) } - return &Lua{state: state}, nil + return &Lua{state: state, base64LuaCode: params.Base64LuaCode}, nil } func (l *Lua) Close() { @@ -71,3 +72,7 @@ func (l *Lua) Parse(data []byte) ([]byte, error) { return result, nil } + +func (l *Lua) Base64LuaCode() string { + return l.base64LuaCode +} diff --git a/device/device.go b/device/device.go index f375663..b42ddc0 100644 --- a/device/device.go +++ b/device/device.go @@ -595,43 +595,43 @@ func (device *Device) resetProtocol() { MessageTransportType = DefaultMessageTransportType } -func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { - if !tempASecCfg.isSet { +func (device *Device) handlePostConfig(tempAwgType *awgType) (err error) { + if !tempAwgType.aSecCfg.isSet { return nil } isASecOn := false device.awg.aSecMux.Lock() - if tempASecCfg.junkPacketCount < 0 { + if tempAwgType.aSecCfg.junkPacketCount < 0 { err = ipcErrorf( ipc.IpcErrorInvalid, "JunkPacketCount should be non negative", ) } - device.awg.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount - if tempASecCfg.junkPacketCount != 0 { + device.awg.aSecCfg.junkPacketCount = tempAwgType.aSecCfg.junkPacketCount + if tempAwgType.aSecCfg.junkPacketCount != 0 { isASecOn = true } - device.awg.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize - if tempASecCfg.junkPacketMinSize != 0 { + device.awg.aSecCfg.junkPacketMinSize = tempAwgType.aSecCfg.junkPacketMinSize + if tempAwgType.aSecCfg.junkPacketMinSize != 0 { isASecOn = true } if device.awg.aSecCfg.junkPacketCount > 0 && - tempASecCfg.junkPacketMaxSize == tempASecCfg.junkPacketMinSize { + tempAwgType.aSecCfg.junkPacketMaxSize == tempAwgType.aSecCfg.junkPacketMinSize { - tempASecCfg.junkPacketMaxSize++ // to make rand gen work + tempAwgType.aSecCfg.junkPacketMaxSize++ // to make rand gen work } - if tempASecCfg.junkPacketMaxSize >= MaxSegmentSize { + if tempAwgType.aSecCfg.junkPacketMaxSize >= MaxSegmentSize { device.awg.aSecCfg.junkPacketMinSize = 0 device.awg.aSecCfg.junkPacketMaxSize = 1 if err != nil { err = ipcErrorf( ipc.IpcErrorInvalid, "JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d; %w", - tempASecCfg.junkPacketMaxSize, + tempAwgType.aSecCfg.junkPacketMaxSize, MaxSegmentSize, err, ) @@ -639,41 +639,41 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { err = ipcErrorf( ipc.IpcErrorInvalid, "JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d", - tempASecCfg.junkPacketMaxSize, + tempAwgType.aSecCfg.junkPacketMaxSize, MaxSegmentSize, ) } - } else if tempASecCfg.junkPacketMaxSize < tempASecCfg.junkPacketMinSize { + } else if tempAwgType.aSecCfg.junkPacketMaxSize < tempAwgType.aSecCfg.junkPacketMinSize { if err != nil { err = ipcErrorf( ipc.IpcErrorInvalid, "maxSize: %d; should be greater than minSize: %d; %w", - tempASecCfg.junkPacketMaxSize, - tempASecCfg.junkPacketMinSize, + tempAwgType.aSecCfg.junkPacketMaxSize, + tempAwgType.aSecCfg.junkPacketMinSize, err, ) } else { err = ipcErrorf( ipc.IpcErrorInvalid, "maxSize: %d; should be greater than minSize: %d", - tempASecCfg.junkPacketMaxSize, - tempASecCfg.junkPacketMinSize, + tempAwgType.aSecCfg.junkPacketMaxSize, + tempAwgType.aSecCfg.junkPacketMinSize, ) } } else { - device.awg.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize + device.awg.aSecCfg.junkPacketMaxSize = tempAwgType.aSecCfg.junkPacketMaxSize } - if tempASecCfg.junkPacketMaxSize != 0 { + if tempAwgType.aSecCfg.junkPacketMaxSize != 0 { isASecOn = true } - if MessageInitiationSize+tempASecCfg.initPacketJunkSize >= MaxSegmentSize { + if MessageInitiationSize+tempAwgType.aSecCfg.initPacketJunkSize >= MaxSegmentSize { if err != nil { err = ipcErrorf( ipc.IpcErrorInvalid, `init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`, - tempASecCfg.initPacketJunkSize, + tempAwgType.aSecCfg.initPacketJunkSize, MaxSegmentSize, err, ) @@ -681,24 +681,24 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { err = ipcErrorf( ipc.IpcErrorInvalid, `init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`, - tempASecCfg.initPacketJunkSize, + tempAwgType.aSecCfg.initPacketJunkSize, MaxSegmentSize, ) } } else { - device.awg.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize + device.awg.aSecCfg.initPacketJunkSize = tempAwgType.aSecCfg.initPacketJunkSize } - if tempASecCfg.initPacketJunkSize != 0 { + if tempAwgType.aSecCfg.initPacketJunkSize != 0 { isASecOn = true } - if MessageResponseSize+tempASecCfg.responsePacketJunkSize >= MaxSegmentSize { + if MessageResponseSize+tempAwgType.aSecCfg.responsePacketJunkSize >= MaxSegmentSize { if err != nil { err = ipcErrorf( ipc.IpcErrorInvalid, `response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`, - tempASecCfg.responsePacketJunkSize, + tempAwgType.aSecCfg.responsePacketJunkSize, MaxSegmentSize, err, ) @@ -706,52 +706,52 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { err = ipcErrorf( ipc.IpcErrorInvalid, `response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`, - tempASecCfg.responsePacketJunkSize, + tempAwgType.aSecCfg.responsePacketJunkSize, MaxSegmentSize, ) } } else { - device.awg.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize + device.awg.aSecCfg.responsePacketJunkSize = tempAwgType.aSecCfg.responsePacketJunkSize } - if tempASecCfg.responsePacketJunkSize != 0 { + if tempAwgType.aSecCfg.responsePacketJunkSize != 0 { isASecOn = true } - if tempASecCfg.initPacketMagicHeader > 4 { + if tempAwgType.aSecCfg.initPacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating init_packet_magic_header") - device.awg.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader + device.awg.aSecCfg.initPacketMagicHeader = tempAwgType.aSecCfg.initPacketMagicHeader MessageInitiationType = device.awg.aSecCfg.initPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default init type") MessageInitiationType = DefaultMessageInitiationType } - if tempASecCfg.responsePacketMagicHeader > 4 { + if tempAwgType.aSecCfg.responsePacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating response_packet_magic_header") - device.awg.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader + device.awg.aSecCfg.responsePacketMagicHeader = tempAwgType.aSecCfg.responsePacketMagicHeader MessageResponseType = device.awg.aSecCfg.responsePacketMagicHeader } else { device.log.Verbosef("UAPI: Using default response type") MessageResponseType = DefaultMessageResponseType } - if tempASecCfg.underloadPacketMagicHeader > 4 { + if tempAwgType.aSecCfg.underloadPacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating underload_packet_magic_header") - device.awg.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader + device.awg.aSecCfg.underloadPacketMagicHeader = tempAwgType.aSecCfg.underloadPacketMagicHeader MessageCookieReplyType = device.awg.aSecCfg.underloadPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default underload type") MessageCookieReplyType = DefaultMessageCookieReplyType } - if tempASecCfg.transportPacketMagicHeader > 4 { + if tempAwgType.aSecCfg.transportPacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating transport_packet_magic_header") - device.awg.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader + device.awg.aSecCfg.transportPacketMagicHeader = tempAwgType.aSecCfg.transportPacketMagicHeader MessageTransportType = device.awg.aSecCfg.transportPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default transport type") @@ -828,6 +828,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { if device.awg.isASecOn.IsSet() { device.awg.junkCreator, err = NewJunkCreator(device) } + device.awg.luaAdapter = tempAwgType.luaAdapter device.awg.aSecMux.Unlock() return err diff --git a/device/uapi.go b/device/uapi.go index 1bce39b..a2bb21b 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -18,6 +18,7 @@ import ( "sync" "time" + "github.com/amnezia-vpn/amneziawg-go/adapter" "github.com/amnezia-vpn/amneziawg-go/ipc" ) @@ -97,6 +98,9 @@ func (device *Device) IpcGetOperation(w io.Writer) error { sendf("fwmark=%d", device.net.fwmark) } + if device.awg.luaAdapter != nil { + sendf("lua_codec=%s", device.awg.luaAdapter.Base64LuaCode()) + } if device.isAdvancedSecurityOn() { if device.awg.aSecCfg.junkPacketCount != 0 { sendf("jc=%d", device.awg.aSecCfg.junkPacketCount) @@ -180,13 +184,13 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { peer := new(ipcSetPeer) deviceConfig := true - tempASecCfg := aSecCfgType{} + tempAwgTpe := awgType{} scanner := bufio.NewScanner(r) for scanner.Scan() { line := scanner.Text() if line == "" { // Blank line means terminate operation. - err := device.handlePostConfig(&tempASecCfg) + err := device.handlePostConfig(&tempAwgTpe) if err != nil { return err } @@ -217,7 +221,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { var err error if deviceConfig { - err = device.handleDeviceLine(key, value, &tempASecCfg) + err = device.handleDeviceLine(key, value, &tempAwgTpe) } else { err = device.handlePeerLine(peer, key, value) } @@ -225,7 +229,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { return err } } - err = device.handlePostConfig(&tempASecCfg) + err = device.handlePostConfig(&tempAwgTpe) if err != nil { return err } @@ -237,7 +241,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { return nil } -func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgType) error { +func (device *Device) handleDeviceLine(key, value string, tempAwgTpe *awgType) error { switch key { case "private_key": var sk NoisePrivateKey @@ -283,14 +287,23 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy device.log.Verbosef("UAPI: Removing all peers") device.RemoveAllPeers() + case "lua_codec": + device.log.Verbosef("UAPI: Updating lua_codec") + var err error + tempAwgTpe.luaAdapter, err = adapter.NewLua(adapter.LuaParams{ + Base64LuaCode: value, + }) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "invalid lua_codec: %w", err) + } case "jc": junkPacketCount, err := strconv.Atoi(value) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_count %w", err) } device.log.Verbosef("UAPI: Updating junk_packet_count") - tempASecCfg.junkPacketCount = junkPacketCount - tempASecCfg.isSet = true + tempAwgTpe.aSecCfg.junkPacketCount = junkPacketCount + tempAwgTpe.aSecCfg.isSet = true case "jmin": junkPacketMinSize, err := strconv.Atoi(value) @@ -298,8 +311,8 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_min_size %w", err) } device.log.Verbosef("UAPI: Updating junk_packet_min_size") - tempASecCfg.junkPacketMinSize = junkPacketMinSize - tempASecCfg.isSet = true + tempAwgTpe.aSecCfg.junkPacketMinSize = junkPacketMinSize + tempAwgTpe.aSecCfg.isSet = true case "jmax": junkPacketMaxSize, err := strconv.Atoi(value) @@ -307,8 +320,8 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_max_size %w", err) } device.log.Verbosef("UAPI: Updating junk_packet_max_size") - tempASecCfg.junkPacketMaxSize = junkPacketMaxSize - tempASecCfg.isSet = true + tempAwgTpe.aSecCfg.junkPacketMaxSize = junkPacketMaxSize + tempAwgTpe.aSecCfg.isSet = true case "s1": initPacketJunkSize, err := strconv.Atoi(value) @@ -316,8 +329,8 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_junk_size %w", err) } device.log.Verbosef("UAPI: Updating init_packet_junk_size") - tempASecCfg.initPacketJunkSize = initPacketJunkSize - tempASecCfg.isSet = true + tempAwgTpe.aSecCfg.initPacketJunkSize = initPacketJunkSize + tempAwgTpe.aSecCfg.isSet = true case "s2": responsePacketJunkSize, err := strconv.Atoi(value) @@ -325,40 +338,40 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_junk_size %w", err) } device.log.Verbosef("UAPI: Updating response_packet_junk_size") - tempASecCfg.responsePacketJunkSize = responsePacketJunkSize - tempASecCfg.isSet = true + tempAwgTpe.aSecCfg.responsePacketJunkSize = responsePacketJunkSize + tempAwgTpe.aSecCfg.isSet = true case "h1": initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_magic_header %w", err) } - tempASecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader) - tempASecCfg.isSet = true + tempAwgTpe.aSecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader) + tempAwgTpe.aSecCfg.isSet = true case "h2": responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_magic_header %w", err) } - tempASecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader) - tempASecCfg.isSet = true + tempAwgTpe.aSecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader) + tempAwgTpe.aSecCfg.isSet = true case "h3": underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse underload_packet_magic_header %w", err) } - tempASecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader) - tempASecCfg.isSet = true + tempAwgTpe.aSecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader) + tempAwgTpe.aSecCfg.isSet = true case "h4": transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse transport_packet_magic_header %w", err) } - tempASecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader) - tempASecCfg.isSet = true + tempAwgTpe.aSecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader) + tempAwgTpe.aSecCfg.isSet = true default: return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key) From 1e532c1e71f36e9cc1ce0fa7df2f8447a1d08f02 Mon Sep 17 00:00:00 2001 From: Mark Puha Date: Sun, 9 Feb 2025 18:26:57 +0100 Subject: [PATCH 09/15] awg-2 working with identity generator --- adapter/lua.go | 1 + adapter/lua_test.go | 6 +++--- device/device.go | 14 +++++++++++++- device/device_test.go | 9 ++++++--- device/receive.go | 29 +++++++++++++++++++++++++++-- device/send.go | 17 ++++------------- 6 files changed, 54 insertions(+), 22 deletions(-) diff --git a/adapter/lua.go b/adapter/lua.go index 371878d..838ff84 100644 --- a/adapter/lua.go +++ b/adapter/lua.go @@ -69,6 +69,7 @@ func (l *Lua) Parse(data []byte) ([]byte, error) { result := l.state.ToBytes(-1) l.state.Pop(1) + // copy(data, result) return result, nil } diff --git a/adapter/lua_test.go b/adapter/lua_test.go index c7027d0..1a3d039 100644 --- a/adapter/lua_test.go +++ b/adapter/lua_test.go @@ -13,11 +13,11 @@ func newLua() *Lua { end function d_parse(data) - local header = "10header" + local header = "1header" return string.sub(data, #header+1) end */ - Base64LuaCode: "CmZ1bmN0aW9uIGRfZ2VuKG1zZ190eXBlLCBkYXRhLCBjb3VudGVyKQoJbG9jYWwgaGVhZGVyID0gImhlYWRlciIKCXJldHVybiBjb3VudGVyIC4uIGhlYWRlciAuLiBkYXRhCmVuZAoKZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJbG9jYWwgaGVhZGVyID0gIjEwaGVhZGVyIgoJcmV0dXJuIHN0cmluZy5zdWIoZGF0YSwgI2hlYWRlcisxKQplbmQK", + Base64LuaCode: "CmZ1bmN0aW9uIGRfZ2VuKG1zZ190eXBlLCBkYXRhLCBjb3VudGVyKQoJbG9jYWwgaGVhZGVyID0gImhlYWRlciIKCXJldHVybiBjb3VudGVyIC4uIGhlYWRlciAuLiBkYXRhCmVuZAoKZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJbG9jYWwgaGVhZGVyID0gIjFoZWFkZXIiCglyZXR1cm4gc3RyaW5nLnN1YihkYXRhLCAjaGVhZGVyKzEpCmVuZAo=", }) return lua } @@ -47,7 +47,7 @@ func TestLua_Parse(t *testing.T) { t.Run("", func(t *testing.T) { l := newLua() defer l.Close() - got, err := l.Parse([]byte("10headertest")) + got, err := l.Parse([]byte("1headertest")) if err != nil { t.Errorf("Lua.Parse() error = %v, wantErr %v", err, nil) return diff --git a/device/device.go b/device/device.go index b42ddc0..b16b460 100644 --- a/device/device.go +++ b/device/device.go @@ -92,7 +92,7 @@ type Device struct { ipcMutex sync.RWMutex closed chan struct{} log *Logger - + awg awgType } @@ -833,3 +833,15 @@ func (device *Device) handlePostConfig(tempAwgType *awgType) (err error) { return err } + +func (device *Device) codecPacket(msgType uint32, packet []byte) ([]byte, error) { + if device.awg.luaAdapter != nil { + var err error + packet, err = device.awg.luaAdapter.Generate(int64(msgType),packet) + if err != nil { + device.log.Errorf("%v - Failed to run codec generate: %v", device, err) + return nil, err + } + } + return packet, nil +} diff --git a/device/device_test.go b/device/device_test.go index e904f26..63f20df 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -107,6 +107,7 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { "private_key", hex.EncodeToString(key1[:]), "listen_port", "0", "replace_peers", "true", + "lua_codec", "CQkJZnVuY3Rpb24gZF9nZW4obXNnX3R5cGUsIGRhdGEsIGNvdW50ZXIpCgkJCQlyZXR1cm4gZGF0YQoJCQllbmQKCgkJCWZ1bmN0aW9uIGRfcGFyc2UoZGF0YSkKCQkJCXJldHVybiBkYXRhCgkJCWVuZAo=", "jc", "5", "jmin", "500", "jmax", "1000", @@ -114,8 +115,8 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { "s2", "40", "h1", "123456", "h2", "67543", - "h4", "32345", "h3", "123123", + "h4", "32345", "public_key", hex.EncodeToString(pub2[:]), "protocol_version", "1", "replace_allowed_ips", "true", @@ -129,6 +130,7 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { "private_key", hex.EncodeToString(key2[:]), "listen_port", "0", "replace_peers", "true", + "lua_codec", "CQkJZnVuY3Rpb24gZF9nZW4obXNnX3R5cGUsIGRhdGEsIGNvdW50ZXIpCgkJCQlyZXR1cm4gZGF0YQoJCQllbmQKCgkJCWZ1bmN0aW9uIGRfcGFyc2UoZGF0YSkKCQkJCXJldHVybiBkYXRhCgkJCWVuZAo=", "jc", "5", "jmin", "500", "jmax", "1000", @@ -136,8 +138,8 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { "s2", "40", "h1", "123456", "h2", "67543", - "h4", "32345", "h3", "123123", + "h4", "32345", "public_key", hex.EncodeToString(pub1[:]), "protocol_version", "1", "replace_allowed_ips", "true", @@ -192,6 +194,7 @@ func (pair *testPair) Send( var err error select { case msgRecv := <-p0.tun.Inbound: + fmt.Printf("len(%d) msg: %x\nlen(%d) rec: %x\n", len(msg), msg, len(msgRecv), msgRecv) if !bytes.Equal(msg, msgRecv) { err = fmt.Errorf("%s did not transit correctly", ping) } @@ -275,7 +278,7 @@ func TestTwoDevicePing(t *testing.T) { } // Run test with -race=false to avoid the race for setting the default msgTypes 2 times -func TestTwoDevicePingASecurity(t *testing.T) { +func TestASecurityTwoDevicePing(t *testing.T) { goroutineLeakCheck(t) pair := genTestPair(t, true, true) t.Run("ping 1.0.0.1", func(t *testing.T) { diff --git a/device/receive.go b/device/receive.go index 70801ef..f8551ae 100644 --- a/device/receive.go +++ b/device/receive.go @@ -9,9 +9,11 @@ import ( "bytes" "encoding/binary" "errors" + "fmt" "net" "sync" "time" + "unsafe" "github.com/amnezia-vpn/amneziawg-go/conn" "golang.org/x/crypto/chacha20poly1305" @@ -138,8 +140,24 @@ func (device *Device) RoutineReceiveIncoming( // check size of packet packet := bufsArrs[i][:size] + fmt.Printf("bufsArrs size: %d\n%.100x\n", size, bufsArrs[i]) + fmt.Printf("packet before: %x\n", packet) if device.awg.luaAdapter != nil { - packet, err = device.awg.luaAdapter.Parse(packet) + ptr:= unsafe.Pointer(bufsArrs[i]) // Get pointer to the array + slicePtr:= (*byte)(ptr) // Type conversion to the array type + + realPacket, err := device.awg.luaAdapter.Parse(packet) + // Copy data from newSlice to the memory pointed to by slicedPtr + newSliceLen:= len(realPacket) + for j:= 0; j < newSliceLen; j++ { + *(*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(slicePtr)) + uintptr(j))) = realPacket[j] + } + fmt.Printf("packet after: %x\n", packet) + fmt.Printf("bufsArs after size: %d\n%.100x\n", size, bufsArrs[i]) + // diff := size - len(packet) + // bufsArrs[i][:len(packet)] = bufsArrs[i][diff:len(packet)] + size = len(packet) + fmt.Println("after size: ", size) if err != nil { device.log.Verbosef("Couldn't parse message; reason: %v", err) continue @@ -151,7 +169,7 @@ func (device *Device) RoutineReceiveIncoming( junkSize := msgTypeToJunkSize[assumedMsgType] // transport size can align with other header types; // making sure we have the right msgType - msgType = binary.LittleEndian.Uint32(packet[junkSize : junkSize+4]) + msgType = binary.LittleEndian.Uint32(packet[junkSize:junkSize+4]) if msgType == assumedMsgType { packet = packet[junkSize:] } else { @@ -285,15 +303,18 @@ func (device *Device) RoutineDecryption(id int) { elem.counter = binary.LittleEndian.Uint64(counter) // copy counter to nonce binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter) + fmt.Printf("before decrypt: %x\n", elem.packet) elem.packet, err = elem.keypair.receive.Open( content[:0], nonce[:], content, nil, ) + if err != nil { elem.packet = nil } + fmt.Printf("decrypt: %x\n", elem.packet) } elemsContainer.Unlock() } @@ -551,10 +572,13 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { continue } + fmt.Printf("bufs packet: %x\n", elem.packet) + fmt.Printf("bufs packet: %x\n", elem.buffer[len(elem.packet)+1:MessageTransportOffsetContent+len(elem.packet)]) bufs = append( bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], ) + fmt.Printf("bufs before send: %.100x\n", elem.buffer) } peer.rxBytes.Add(rxBytesLen) @@ -568,6 +592,7 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { peer.timersDataReceived() } if len(bufs) > 0 { + fmt.Printf("bufs: %x\n", bufs) _, err := device.tun.device.Write(bufs, MessageTransportOffsetContent) if err != nil && !device.isClosed() { device.log.Errorf("Failed to write packets to TUN device: %v", err) diff --git a/device/send.go b/device/send.go index 297f19e..85ad169 100644 --- a/device/send.go +++ b/device/send.go @@ -9,6 +9,7 @@ import ( "bytes" "encoding/binary" "errors" + "fmt" "net" "os" "sync" @@ -547,18 +548,6 @@ func calculatePaddingSize(packetSize, mtu int) int { return paddedSize - lastUnit } -func (device *Device) codecPacket(msgType uint32, packet []byte) ([]byte, error) { - if device.awg.luaAdapter != nil { - var err error - packet, err = device.awg.luaAdapter.Generate(int64(msgType),packet) - if err != nil { - device.log.Errorf("%v - Failed to run codec generate: %v", device, err) - return nil, err - } - } - return packet, nil -} - /* Encrypts the elements in the queue * and marks them for sequential consumption (by releasing the mutex) * @@ -603,11 +592,12 @@ func (device *Device) RoutineEncryption(id int) { elem.packet, nil, ) - + fmt.Printf("msg: %x\n", elem.packet) var err error if elem.packet, err = device.codecPacket(DefaultMessageTransportType, elem.packet); err != nil { continue } + fmt.Printf("msgmsg: %x\n", elem.packet) } elemsContainer.Unlock() } @@ -662,6 +652,7 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { peer.timersDataSent() } for _, elem := range elemsContainer.elems { + fmt.Printf("send buffer: %.200x\n", elem.buffer) device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } From 068e1465387fff82501450532587ee58be60dff0 Mon Sep 17 00:00:00 2001 From: Mark Puha Date: Sun, 9 Feb 2025 18:34:36 +0100 Subject: [PATCH 10/15] working codec --- adapter/lua.go | 1 - device/device_test.go | 5 ++--- device/receive.go | 35 +++++++++++------------------------ device/send.go | 4 ---- 4 files changed, 13 insertions(+), 32 deletions(-) diff --git a/adapter/lua.go b/adapter/lua.go index 838ff84..371878d 100644 --- a/adapter/lua.go +++ b/adapter/lua.go @@ -69,7 +69,6 @@ func (l *Lua) Parse(data []byte) ([]byte, error) { result := l.state.ToBytes(-1) l.state.Pop(1) - // copy(data, result) return result, nil } diff --git a/device/device_test.go b/device/device_test.go index 63f20df..367f097 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -107,7 +107,7 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { "private_key", hex.EncodeToString(key1[:]), "listen_port", "0", "replace_peers", "true", - "lua_codec", "CQkJZnVuY3Rpb24gZF9nZW4obXNnX3R5cGUsIGRhdGEsIGNvdW50ZXIpCgkJCQlyZXR1cm4gZGF0YQoJCQllbmQKCgkJCWZ1bmN0aW9uIGRfcGFyc2UoZGF0YSkKCQkJCXJldHVybiBkYXRhCgkJCWVuZAo=", + "lua_codec", "CQkJZnVuY3Rpb24gZF9nZW4obXNnX3R5cGUsIGRhdGEsIGNvdW50ZXIpCgkJCQlsb2NhbCBoZWFkZXIgPSAiaGVhZGVyIgoJCQkJcmV0dXJuIGhlYWRlciAuLiBkYXRhCgkJCWVuZAoKCQkJZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJCQkJbG9jYWwgaGVhZGVyID0gImhlYWRlciIKCQkJCXJldHVybiBzdHJpbmcuc3ViKGRhdGEsICNoZWFkZXIrMSkKCQkJZW5kCg==", "jc", "5", "jmin", "500", "jmax", "1000", @@ -130,7 +130,7 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { "private_key", hex.EncodeToString(key2[:]), "listen_port", "0", "replace_peers", "true", - "lua_codec", "CQkJZnVuY3Rpb24gZF9nZW4obXNnX3R5cGUsIGRhdGEsIGNvdW50ZXIpCgkJCQlyZXR1cm4gZGF0YQoJCQllbmQKCgkJCWZ1bmN0aW9uIGRfcGFyc2UoZGF0YSkKCQkJCXJldHVybiBkYXRhCgkJCWVuZAo=", + "lua_codec", "CQkJZnVuY3Rpb24gZF9nZW4obXNnX3R5cGUsIGRhdGEsIGNvdW50ZXIpCgkJCQlsb2NhbCBoZWFkZXIgPSAiaGVhZGVyIgoJCQkJcmV0dXJuIGhlYWRlciAuLiBkYXRhCgkJCWVuZAoKCQkJZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJCQkJbG9jYWwgaGVhZGVyID0gImhlYWRlciIKCQkJCXJldHVybiBzdHJpbmcuc3ViKGRhdGEsICNoZWFkZXIrMSkKCQkJZW5kCg==", "jc", "5", "jmin", "500", "jmax", "1000", @@ -194,7 +194,6 @@ func (pair *testPair) Send( var err error select { case msgRecv := <-p0.tun.Inbound: - fmt.Printf("len(%d) msg: %x\nlen(%d) rec: %x\n", len(msg), msg, len(msgRecv), msgRecv) if !bytes.Equal(msg, msgRecv) { err = fmt.Errorf("%s did not transit correctly", ping) } diff --git a/device/receive.go b/device/receive.go index f8551ae..4ac5783 100644 --- a/device/receive.go +++ b/device/receive.go @@ -9,7 +9,6 @@ import ( "bytes" "encoding/binary" "errors" - "fmt" "net" "sync" "time" @@ -138,28 +137,22 @@ func (device *Device) RoutineReceiveIncoming( continue } - // check size of packet packet := bufsArrs[i][:size] - fmt.Printf("bufsArrs size: %d\n%.100x\n", size, bufsArrs[i]) - fmt.Printf("packet before: %x\n", packet) if device.awg.luaAdapter != nil { - ptr:= unsafe.Pointer(bufsArrs[i]) // Get pointer to the array - slicePtr:= (*byte)(ptr) // Type conversion to the array type - realPacket, err := device.awg.luaAdapter.Parse(packet) - // Copy data from newSlice to the memory pointed to by slicedPtr - newSliceLen:= len(realPacket) - for j:= 0; j < newSliceLen; j++ { - *(*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(slicePtr)) + uintptr(j))) = realPacket[j] + + packetPtr := (*byte)(unsafe.Pointer(bufsArrs[i])) // Get pointer to the array + // Copy data from realPacket to the memory pointed to by bufsArrs[i] + size = len(realPacket) + for j := 0; j < size; j++ { + *(*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(packetPtr)) + uintptr(j))) = realPacket[j] } - fmt.Printf("packet after: %x\n", packet) - fmt.Printf("bufsArs after size: %d\n%.100x\n", size, bufsArrs[i]) - // diff := size - len(packet) - // bufsArrs[i][:len(packet)] = bufsArrs[i][diff:len(packet)] - size = len(packet) - fmt.Println("after size: ", size) + packet = bufsArrs[i][:size] if err != nil { - device.log.Verbosef("Couldn't parse message; reason: %v", err) + device.log.Verbosef( + "Couldn't parse message; reason: %v", + err, + ) continue } } @@ -303,7 +296,6 @@ func (device *Device) RoutineDecryption(id int) { elem.counter = binary.LittleEndian.Uint64(counter) // copy counter to nonce binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter) - fmt.Printf("before decrypt: %x\n", elem.packet) elem.packet, err = elem.keypair.receive.Open( content[:0], nonce[:], @@ -314,7 +306,6 @@ func (device *Device) RoutineDecryption(id int) { if err != nil { elem.packet = nil } - fmt.Printf("decrypt: %x\n", elem.packet) } elemsContainer.Unlock() } @@ -572,13 +563,10 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { continue } - fmt.Printf("bufs packet: %x\n", elem.packet) - fmt.Printf("bufs packet: %x\n", elem.buffer[len(elem.packet)+1:MessageTransportOffsetContent+len(elem.packet)]) bufs = append( bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], ) - fmt.Printf("bufs before send: %.100x\n", elem.buffer) } peer.rxBytes.Add(rxBytesLen) @@ -592,7 +580,6 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { peer.timersDataReceived() } if len(bufs) > 0 { - fmt.Printf("bufs: %x\n", bufs) _, err := device.tun.device.Write(bufs, MessageTransportOffsetContent) if err != nil && !device.isClosed() { device.log.Errorf("Failed to write packets to TUN device: %v", err) diff --git a/device/send.go b/device/send.go index 85ad169..5acd81c 100644 --- a/device/send.go +++ b/device/send.go @@ -9,7 +9,6 @@ import ( "bytes" "encoding/binary" "errors" - "fmt" "net" "os" "sync" @@ -592,12 +591,10 @@ func (device *Device) RoutineEncryption(id int) { elem.packet, nil, ) - fmt.Printf("msg: %x\n", elem.packet) var err error if elem.packet, err = device.codecPacket(DefaultMessageTransportType, elem.packet); err != nil { continue } - fmt.Printf("msgmsg: %x\n", elem.packet) } elemsContainer.Unlock() } @@ -652,7 +649,6 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { peer.timersDataSent() } for _, elem := range elemsContainer.elems { - fmt.Printf("send buffer: %.200x\n", elem.buffer) device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } From 0dcd528694ad8c9babd807cb38bce2aea64d3126 Mon Sep 17 00:00:00 2001 From: Mark Puha Date: Mon, 10 Feb 2025 10:05:42 +0100 Subject: [PATCH 11/15] rename vars and complete testing --- adapter/lua.go | 55 +++++++++++++++++++++++++++------------- adapter/lua_test.go | 22 ++++++++++++++++ device/device.go | 26 +++++++++++-------- device/device_test.go | 15 +++++++++-- device/noise-protocol.go | 20 +++++++-------- device/receive.go | 20 +++++---------- device/send.go | 25 +++++++++--------- device/uapi.go | 6 ++--- 8 files changed, 121 insertions(+), 68 deletions(-) diff --git a/adapter/lua.go b/adapter/lua.go index 371878d..4677345 100644 --- a/adapter/lua.go +++ b/adapter/lua.go @@ -8,9 +8,9 @@ import ( "github.com/aarzilli/golua/lua" ) -// TODO: aSec sync is enough? type Lua struct { - state *lua.State + generateState *lua.State + parseState *lua.State packetCounter atomic.Int64 base64LuaCode string } @@ -24,51 +24,72 @@ func NewLua(params LuaParams) (*Lua, error) { if err != nil { return nil, err } - fmt.Println(string(luaCode)) + strLuaCode := string(luaCode) + + generateState, err := initState(strLuaCode) + if err != nil { + return nil, err + } + parseState, err := initState(strLuaCode) + if err != nil { + return nil, err + } + + return &Lua{ + generateState: generateState, + parseState: parseState, + base64LuaCode: params.Base64LuaCode, + }, nil +} + +func initState(luaCode string) (*lua.State, error) { state := lua.NewState() state.OpenLibs() if err := state.DoString(string(luaCode)); err != nil { return nil, fmt.Errorf("Error loading Lua code: %v\n", err) } - return &Lua{state: state, base64LuaCode: params.Base64LuaCode}, nil + return state, nil } func (l *Lua) Close() { - l.state.Close() + l.generateState.Close() + l.parseState.Close() } +// Only thread safe if used by wg packet creation which happens independably func (l *Lua) Generate( msgType int64, data []byte, ) ([]byte, error) { - l.state.GetGlobal("d_gen") + l.generateState.GetGlobal("d_gen") - l.state.PushInteger(msgType) - l.state.PushBytes(data) - l.state.PushInteger(l.packetCounter.Add(1)) + l.generateState.PushInteger(msgType) + l.generateState.PushBytes(data) + l.generateState.PushInteger(l.packetCounter.Add(1)) - if err := l.state.Call(3, 1); err != nil { + if err := l.generateState.Call(3, 1); err != nil { return nil, fmt.Errorf("Error calling Lua function: %v\n", err) } - result := l.state.ToBytes(-1) - l.state.Pop(1) + result := l.generateState.ToBytes(-1) + l.generateState.Pop(1) return result, nil } +// Only thread safe if used by wg packet receive which happens independably func (l *Lua) Parse(data []byte) ([]byte, error) { - l.state.GetGlobal("d_parse") + l.parseState.GetGlobal("d_parse") - l.state.PushBytes(data) - if err := l.state.Call(1, 1); err != nil { + l.parseState.PushBytes(data) + if err := l.parseState.Call(1, 1); err != nil { return nil, fmt.Errorf("Error calling Lua function: %v\n", err) } - result := l.state.ToBytes(-1) - l.state.Pop(1) + result := l.parseState.ToBytes(-1) + l.parseState.Pop(1) return result, nil } diff --git a/adapter/lua_test.go b/adapter/lua_test.go index 1a3d039..88389f5 100644 --- a/adapter/lua_test.go +++ b/adapter/lua_test.go @@ -58,3 +58,25 @@ func TestLua_Parse(t *testing.T) { } }) } + + +var R []byte +var l = newLua() +func BenchmarkLuaGenerate(b *testing.B) { + for i := 0; i < b.N; i++ { var err error + R, err = l.Generate(1, []byte("test")) + if err != nil { + return + } + } +} + +func BenchmarkLuaParse(b *testing.B) { + for i := 0; i < b.N; i++ { + var err error + R, err = l.Parse([]byte("1headertest")) + if err != nil { + return + } + } +} diff --git a/device/device.go b/device/device.go index b16b460..8a96ffd 100644 --- a/device/device.go +++ b/device/device.go @@ -98,11 +98,11 @@ type Device struct { type awgType struct { isASecOn abool.AtomicBool - aSecMux sync.RWMutex + mutex sync.RWMutex aSecCfg aSecCfgType junkCreator junkCreator - luaAdapter *adapter.Lua + codec *adapter.Lua } type aSecCfgType struct { @@ -434,9 +434,9 @@ func (device *Device) Close() { device.resetProtocol() - if device.awg.luaAdapter != nil { - device.awg.luaAdapter.Close() - device.awg.luaAdapter = nil + if device.awg.codec != nil { + device.awg.codec.Close() + device.awg.codec = nil } device.log.Verbosef("Device closed") close(device.closed) @@ -601,7 +601,7 @@ func (device *Device) handlePostConfig(tempAwgType *awgType) (err error) { } isASecOn := false - device.awg.aSecMux.Lock() + device.awg.mutex.Lock() if tempAwgType.aSecCfg.junkPacketCount < 0 { err = ipcErrorf( ipc.IpcErrorInvalid, @@ -828,16 +828,20 @@ func (device *Device) handlePostConfig(tempAwgType *awgType) (err error) { if device.awg.isASecOn.IsSet() { device.awg.junkCreator, err = NewJunkCreator(device) } - device.awg.luaAdapter = tempAwgType.luaAdapter - device.awg.aSecMux.Unlock() + device.awg.codec = tempAwgType.codec + device.awg.mutex.Unlock() return err } -func (device *Device) codecPacket(msgType uint32, packet []byte) ([]byte, error) { - if device.awg.luaAdapter != nil { +func (device *Device) isCodecActive() bool { + return device.awg.codec != nil +} + +func (device *Device) codecPacketIfActive(msgType uint32, packet []byte) ([]byte, error) { + if device.isCodecActive() { var err error - packet, err = device.awg.luaAdapter.Generate(int64(msgType),packet) + packet, err = device.awg.codec.Generate(int64(msgType),packet) if err != nil { device.log.Errorf("%v - Failed to run codec generate: %v", device, err) return nil, err diff --git a/device/device_test.go b/device/device_test.go index 367f097..0f42c46 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -103,11 +103,22 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { } pub1, pub2 := key1.publicKey(), key2.publicKey() + /* + head = "headheadhead" + tail = "tailtailtail" + function d_gen(msg_type, data, counter) + return head .. data .. tail + end + + function d_parse(data) + return string.match(data, head.. "(.-)".. tail) + end + */ cfgs[0] = uapiCfg( "private_key", hex.EncodeToString(key1[:]), "listen_port", "0", "replace_peers", "true", - "lua_codec", "CQkJZnVuY3Rpb24gZF9nZW4obXNnX3R5cGUsIGRhdGEsIGNvdW50ZXIpCgkJCQlsb2NhbCBoZWFkZXIgPSAiaGVhZGVyIgoJCQkJcmV0dXJuIGhlYWRlciAuLiBkYXRhCgkJCWVuZAoKCQkJZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJCQkJbG9jYWwgaGVhZGVyID0gImhlYWRlciIKCQkJCXJldHVybiBzdHJpbmcuc3ViKGRhdGEsICNoZWFkZXIrMSkKCQkJZW5kCg==", + "lua_codec", "aGVhZCA9ICJoZWFkaGVhZGhlYWQiCnRhaWwgPSAidGFpbHRhaWx0YWlsIgpmdW5jdGlvbiBkX2dlbihtc2dfdHlwZSwgZGF0YSwgY291bnRlcikKCXJldHVybiBoZWFkIC4uIGRhdGEgLi4gdGFpbAplbmQKCmZ1bmN0aW9uIGRfcGFyc2UoZGF0YSkKICAgICAgICByZXR1cm4gc3RyaW5nLm1hdGNoKGRhdGEsIGhlYWQuLiAiKC4tKSIuLiB0YWlsKQplbmQK", "jc", "5", "jmin", "500", "jmax", "1000", @@ -130,7 +141,7 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { "private_key", hex.EncodeToString(key2[:]), "listen_port", "0", "replace_peers", "true", - "lua_codec", "CQkJZnVuY3Rpb24gZF9nZW4obXNnX3R5cGUsIGRhdGEsIGNvdW50ZXIpCgkJCQlsb2NhbCBoZWFkZXIgPSAiaGVhZGVyIgoJCQkJcmV0dXJuIGhlYWRlciAuLiBkYXRhCgkJCWVuZAoKCQkJZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJCQkJbG9jYWwgaGVhZGVyID0gImhlYWRlciIKCQkJCXJldHVybiBzdHJpbmcuc3ViKGRhdGEsICNoZWFkZXIrMSkKCQkJZW5kCg==", + "lua_codec", "aGVhZCA9ICJoZWFkaGVhZGhlYWQiCnRhaWwgPSAidGFpbHRhaWx0YWlsIgpmdW5jdGlvbiBkX2dlbihtc2dfdHlwZSwgZGF0YSwgY291bnRlcikKCXJldHVybiBoZWFkIC4uIGRhdGEgLi4gdGFpbAplbmQKCmZ1bmN0aW9uIGRfcGFyc2UoZGF0YSkKICAgICAgICByZXR1cm4gc3RyaW5nLm1hdGNoKGRhdGEsIGhlYWQuLiAiKC4tKSIuLiB0YWlsKQplbmQK", "jc", "5", "jmin", "500", "jmax", "1000", diff --git a/device/noise-protocol.go b/device/noise-protocol.go index 2dc0d93..75352c0 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -204,12 +204,12 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e handshake.mixHash(handshake.remoteStatic[:]) - device.awg.aSecMux.RLock() + device.awg.mutex.RLock() msg := MessageInitiation{ Type: MessageInitiationType, Ephemeral: handshake.localEphemeral.publicKey(), } - device.awg.aSecMux.RUnlock() + device.awg.mutex.RUnlock() handshake.mixKey(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:]) @@ -263,12 +263,12 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { chainKey [blake2s.Size]byte ) - device.awg.aSecMux.RLock() + device.awg.mutex.RLock() if msg.Type != MessageInitiationType { - device.awg.aSecMux.RUnlock() + device.awg.mutex.RUnlock() return nil } - device.awg.aSecMux.RUnlock() + device.awg.mutex.RUnlock() device.staticIdentity.RLock() defer device.staticIdentity.RUnlock() @@ -383,9 +383,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error } var msg MessageResponse - device.awg.aSecMux.RLock() + device.awg.mutex.RLock() msg.Type = MessageResponseType - device.awg.aSecMux.RUnlock() + device.awg.mutex.RUnlock() msg.Sender = handshake.localIndex msg.Receiver = handshake.remoteIndex @@ -435,12 +435,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error } func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { - device.awg.aSecMux.RLock() + device.awg.mutex.RLock() if msg.Type != MessageResponseType { - device.awg.aSecMux.RUnlock() + device.awg.mutex.RUnlock() return nil } - device.awg.aSecMux.RUnlock() + device.awg.mutex.RUnlock() // lookup handshake by receiver diff --git a/device/receive.go b/device/receive.go index 4ac5783..3df269e 100644 --- a/device/receive.go +++ b/device/receive.go @@ -12,7 +12,6 @@ import ( "net" "sync" "time" - "unsafe" "github.com/amnezia-vpn/amneziawg-go/conn" "golang.org/x/crypto/chacha20poly1305" @@ -130,7 +129,7 @@ func (device *Device) RoutineReceiveIncoming( } deathSpiral = 0 - device.awg.aSecMux.RLock() + device.awg.mutex.RLock() // handle each packet in the batch for i, size := range sizes[:count] { if size < MinMessageSize { @@ -138,15 +137,10 @@ func (device *Device) RoutineReceiveIncoming( } packet := bufsArrs[i][:size] - if device.awg.luaAdapter != nil { - realPacket, err := device.awg.luaAdapter.Parse(packet) - - packetPtr := (*byte)(unsafe.Pointer(bufsArrs[i])) // Get pointer to the array - // Copy data from realPacket to the memory pointed to by bufsArrs[i] + if device.isCodecActive() { + realPacket, err := device.awg.codec.Parse(packet) + copy(packet, realPacket) size = len(realPacket) - for j := 0; j < size; j++ { - *(*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(packetPtr)) + uintptr(j))) = realPacket[j] - } packet = bufsArrs[i][:size] if err != nil { device.log.Verbosef( @@ -262,7 +256,7 @@ func (device *Device) RoutineReceiveIncoming( default: } } - device.awg.aSecMux.RUnlock() + device.awg.mutex.RUnlock() for peer, elemsContainer := range elemsByPeer { if peer.isRunning.Load() { peer.queue.inbound.c <- elemsContainer @@ -322,7 +316,7 @@ func (device *Device) RoutineHandshake(id int) { for elem := range device.queue.handshake.c { - device.awg.aSecMux.RLock() + device.awg.mutex.RLock() // handle cookie fields and ratelimiting @@ -474,7 +468,7 @@ func (device *Device) RoutineHandshake(id int) { peer.SendKeepalive() } skip: - device.awg.aSecMux.RUnlock() + device.awg.mutex.RUnlock() device.PutMessageBuffer(elem.buffer) } } diff --git a/device/send.go b/device/send.go index 5acd81c..d6352e9 100644 --- a/device/send.go +++ b/device/send.go @@ -127,13 +127,14 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { ) return err } + var sendBuffer [][]byte // so only packet processed for cookie generation var junkedHeader []byte if peer.device.isAdvancedSecurityOn() { - peer.device.awg.aSecMux.RLock() + peer.device.awg.mutex.RLock() junks, err := peer.device.awg.junkCreator.createJunkPackets(peer) - peer.device.awg.aSecMux.RUnlock() + peer.device.awg.mutex.RUnlock() if err != nil { peer.device.log.Errorf("%v - %v", peer, err) @@ -153,19 +154,19 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { } } - peer.device.awg.aSecMux.RLock() + peer.device.awg.mutex.RLock() if peer.device.awg.aSecCfg.initPacketJunkSize != 0 { buf := make([]byte, 0, peer.device.awg.aSecCfg.initPacketJunkSize) writer := bytes.NewBuffer(buf[:0]) err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.initPacketJunkSize) if err != nil { peer.device.log.Errorf("%v - %v", peer, err) - peer.device.awg.aSecMux.RUnlock() + peer.device.awg.mutex.RUnlock() return err } junkedHeader = writer.Bytes() } - peer.device.awg.aSecMux.RUnlock() + peer.device.awg.mutex.RUnlock() } var buf [MessageInitiationSize]byte @@ -175,7 +176,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { peer.cookieGenerator.AddMacs(packet) junkedHeader = append(junkedHeader, packet...) - if junkedHeader, err = peer.device.codecPacket(DefaultMessageInitiationType, junkedHeader); err != nil { + if junkedHeader, err = peer.device.codecPacketIfActive(DefaultMessageInitiationType, junkedHeader); err != nil { return err } @@ -215,19 +216,19 @@ func (peer *Peer) SendHandshakeResponse() error { } var junkedHeader []byte if peer.device.isAdvancedSecurityOn() { - peer.device.awg.aSecMux.RLock() + peer.device.awg.mutex.RLock() if peer.device.awg.aSecCfg.responsePacketJunkSize != 0 { buf := make([]byte, 0, peer.device.awg.aSecCfg.responsePacketJunkSize) writer := bytes.NewBuffer(buf[:0]) err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.responsePacketJunkSize) if err != nil { - peer.device.awg.aSecMux.RUnlock() + peer.device.awg.mutex.RUnlock() peer.device.log.Errorf("%v - %v", peer, err) return err } junkedHeader = writer.Bytes() } - peer.device.awg.aSecMux.RUnlock() + peer.device.awg.mutex.RUnlock() } var buf [MessageResponseSize]byte writer := bytes.NewBuffer(buf[:0]) @@ -237,7 +238,7 @@ func (peer *Peer) SendHandshakeResponse() error { peer.cookieGenerator.AddMacs(packet) junkedHeader = append(junkedHeader, packet...) - if junkedHeader, err = peer.device.codecPacket(DefaultMessageResponseType, junkedHeader); err != nil { + if junkedHeader, err = peer.device.codecPacketIfActive(DefaultMessageResponseType, junkedHeader); err != nil { return err } @@ -286,7 +287,7 @@ func (device *Device) SendHandshakeCookie( writer := bytes.NewBuffer(buf[:0]) binary.Write(writer, binary.LittleEndian, reply) packet := writer.Bytes() - if packet, err = device.codecPacket(DefaultMessageCookieReplyType, packet); err != nil { + if packet, err = device.codecPacketIfActive(DefaultMessageCookieReplyType, packet); err != nil { return err } @@ -592,7 +593,7 @@ func (device *Device) RoutineEncryption(id int) { nil, ) var err error - if elem.packet, err = device.codecPacket(DefaultMessageTransportType, elem.packet); err != nil { + if elem.packet, err = device.codecPacketIfActive(DefaultMessageTransportType, elem.packet); err != nil { continue } } diff --git a/device/uapi.go b/device/uapi.go index a2bb21b..850cc6c 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -98,8 +98,8 @@ func (device *Device) IpcGetOperation(w io.Writer) error { sendf("fwmark=%d", device.net.fwmark) } - if device.awg.luaAdapter != nil { - sendf("lua_codec=%s", device.awg.luaAdapter.Base64LuaCode()) + if device.awg.codec != nil { + sendf("lua_codec=%s", device.awg.codec.Base64LuaCode()) } if device.isAdvancedSecurityOn() { if device.awg.aSecCfg.junkPacketCount != 0 { @@ -290,7 +290,7 @@ func (device *Device) handleDeviceLine(key, value string, tempAwgTpe *awgType) e case "lua_codec": device.log.Verbosef("UAPI: Updating lua_codec") var err error - tempAwgTpe.luaAdapter, err = adapter.NewLua(adapter.LuaParams{ + tempAwgTpe.codec, err = adapter.NewLua(adapter.LuaParams{ Base64LuaCode: value, }) if err != nil { From ae81fcbd7ab0ae44d68ebcc2c80b105aa948e3d1 Mon Sep 17 00:00:00 2001 From: Mark Puha Date: Mon, 10 Feb 2025 10:40:34 +0100 Subject: [PATCH 12/15] delete adapter makefile --- adapter/Makefile | 2 -- 1 file changed, 2 deletions(-) delete mode 100644 adapter/Makefile diff --git a/adapter/Makefile b/adapter/Makefile deleted file mode 100644 index 41bc7aa..0000000 --- a/adapter/Makefile +++ /dev/null @@ -1,2 +0,0 @@ -all: - go run -tags lua54 lua.go From d03bddd7c77b585555def6d78ecbe45e0b9c876a Mon Sep 17 00:00:00 2001 From: Mark Puha Date: Mon, 10 Feb 2025 10:43:18 +0100 Subject: [PATCH 13/15] add luajit tag for lua codec --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 7a88647..979681a 100644 --- a/Makefile +++ b/Makefile @@ -17,7 +17,7 @@ generate-version-and-build: @$(MAKE) amneziawg-go amneziawg-go: $(wildcard *.go) $(wildcard */*.go) - go build -v -o "$@" + go build -tags luajit -v -o "$@" install: amneziawg-go @install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/amneziawg-go" From 02462a0e2fc02e78a2282c1387ba0cafc5f11040 Mon Sep 17 00:00:00 2001 From: Mark Puha Date: Tue, 11 Feb 2025 08:08:22 +0100 Subject: [PATCH 14/15] always assign lua_codec in postConfig --- device/device.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/device/device.go b/device/device.go index 8a96ffd..6a83bf9 100644 --- a/device/device.go +++ b/device/device.go @@ -98,7 +98,7 @@ type Device struct { type awgType struct { isASecOn abool.AtomicBool - mutex sync.RWMutex + mutex sync.RWMutex aSecCfg aSecCfgType junkCreator junkCreator @@ -596,6 +596,8 @@ func (device *Device) resetProtocol() { } func (device *Device) handlePostConfig(tempAwgType *awgType) (err error) { + device.awg.codec = tempAwgType.codec + if !tempAwgType.aSecCfg.isSet { return nil } @@ -828,7 +830,6 @@ func (device *Device) handlePostConfig(tempAwgType *awgType) (err error) { if device.awg.isASecOn.IsSet() { device.awg.junkCreator, err = NewJunkCreator(device) } - device.awg.codec = tempAwgType.codec device.awg.mutex.Unlock() return err From c486c2eb5d46a1cd4df09b7818b565dca064561a Mon Sep 17 00:00:00 2001 From: Mark Puha Date: Tue, 11 Feb 2025 08:40:08 +0100 Subject: [PATCH 15/15] relocate adapter to internal structure --- device/device.go | 2 +- {adapter => device/internal/adapter}/lua.go | 1 + {adapter => device/internal/adapter}/lua_test.go | 0 device/uapi.go | 2 +- 4 files changed, 3 insertions(+), 2 deletions(-) rename {adapter => device/internal/adapter}/lua.go (98%) rename {adapter => device/internal/adapter}/lua_test.go (100%) diff --git a/device/device.go b/device/device.go index 6a83bf9..1e28e8f 100644 --- a/device/device.go +++ b/device/device.go @@ -11,8 +11,8 @@ import ( "sync/atomic" "time" - "github.com/amnezia-vpn/amneziawg-go/adapter" "github.com/amnezia-vpn/amneziawg-go/conn" + "github.com/amnezia-vpn/amneziawg-go/device/internal/adapter" "github.com/amnezia-vpn/amneziawg-go/ipc" "github.com/amnezia-vpn/amneziawg-go/ratelimiter" "github.com/amnezia-vpn/amneziawg-go/rwcancel" diff --git a/adapter/lua.go b/device/internal/adapter/lua.go similarity index 98% rename from adapter/lua.go rename to device/internal/adapter/lua.go index 4677345..9c5e07d 100644 --- a/adapter/lua.go +++ b/device/internal/adapter/lua.go @@ -26,6 +26,7 @@ func NewLua(params LuaParams) (*Lua, error) { } strLuaCode := string(luaCode) + // fmt.Println(strLuaCode) generateState, err := initState(strLuaCode) if err != nil { diff --git a/adapter/lua_test.go b/device/internal/adapter/lua_test.go similarity index 100% rename from adapter/lua_test.go rename to device/internal/adapter/lua_test.go diff --git a/device/uapi.go b/device/uapi.go index 850cc6c..7bd5209 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -18,7 +18,7 @@ import ( "sync" "time" - "github.com/amnezia-vpn/amneziawg-go/adapter" + "github.com/amnezia-vpn/amneziawg-go/device/internal/adapter" "github.com/amnezia-vpn/amneziawg-go/ipc" )