From 44217caa0d41884cc9a21b447fbfa43cceb1d2eb Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Sat, 8 Feb 2025 13:37:52 +0100
Subject: [PATCH 01/15] successfully call lua function from Go
---
adapter/Makefile | 2 ++
adapter/lua.go | 45 +++++++++++++++++++++++++++++++++++++++++++++
go.mod | 1 +
go.sum | 2 ++
4 files changed, 50 insertions(+)
create mode 100644 adapter/Makefile
create mode 100644 adapter/lua.go
diff --git a/adapter/Makefile b/adapter/Makefile
new file mode 100644
index 0000000..41bc7aa
--- /dev/null
+++ b/adapter/Makefile
@@ -0,0 +1,2 @@
+all:
+ go run -tags lua54 lua.go
diff --git a/adapter/lua.go b/adapter/lua.go
new file mode 100644
index 0000000..8c56d4a
--- /dev/null
+++ b/adapter/lua.go
@@ -0,0 +1,45 @@
+package main
+
+import (
+ "encoding/base64"
+ "fmt"
+
+ "github.com/aarzilli/golua/lua"
+)
+
+func main() {
+ // luaB64 := `bG9jYWwgZnVuY3Rpb24gZF9nZW4oZGF0YSwgY291bnRlcikKCUhlYWRlciA9IHN0cmluZy5jaGFyKDB4MTIsIDB4MzQsIDB4NTYsIDB4NzgpCgktLSBsb2NhbCB0cyA9IG9zLnRpbWUoKQoJcmV0dXJuIEhlYWRlciAuLiBkYXRhCmVuZAoKbG9jYWwgZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJcmV0dXJuIHN0cmluZy5zdWIoZGF0YSwgI0hlYWRlcikKZW5kCg==`
+ // only d_gen
+ // luaB64 := `ZnVuY3Rpb24gRF9nZW4oZGF0YSkKCS0tIEhlYWRlciA9IHN0cmluZy5jaGFyKDB4MTIsIDB4MzQsIDB4NTYsIDB4NzgpCglsb2NhbCBIZWFkZXIgPSAiXHgxMlx4MzRceDU2XHg3OCIKCS0tIGxvY2FsIHRzID0gb3MudGltZSgpCglyZXR1cm4gSGVhZGVyIC4uIGRhdGEKZW5kCg==`
+ luaB64 := `ZnVuY3Rpb24gRF9nZW4oZGF0YSkKCS0tIEhlYWRlciA9IHN0cmluZy5jaGFyKDB4MTIsIDB4MzQsIDB4NTYsIDB4NzgpCglsb2NhbCBIZWFkZXIgPSAiXHgxMlx4MzRceDU2XHg3OCIKCWxvY2FsIHRzID0gb3MudGltZSgpCglyZXR1cm4gSGVhZGVyIC4uIGRhdGEKZW5kCg==`
+ sDec, _ := base64.StdEncoding.DecodeString(luaB64)
+ fmt.Println(string(sDec))
+ luaCode := sDec
+ L := lua.NewState()
+ L.OpenLibs()
+ defer L.Close()
+
+ // Load and execute the Lua code
+ if err := L.DoString(string(luaCode)); err != nil {
+ fmt.Printf("Error loading Lua code: %v\n", err)
+ return
+ }
+
+ // Push the function onto the stack
+ L.GetGlobal("D_gen")
+
+ // Push the argument
+ L.PushString("data")
+
+ if err := L.Call(1, 1); err != nil {
+ fmt.Printf("Error calling Lua function: %v\n", err)
+ return
+ }
+
+ result := L.ToString(-1)
+ L.Pop(1)
+
+ // Print the result
+ // fmt.Printf("Result: %x\n", []byte(result))
+ fmt.Printf("Result: %s\n", result)
+}
diff --git a/go.mod b/go.mod
index b03acb0..8511642 100644
--- a/go.mod
+++ b/go.mod
@@ -12,6 +12,7 @@ require (
)
require (
+ github.com/aarzilli/golua v0.0.0-20241229084300-cd31ab23902e // indirect
github.com/google/btree v1.0.1 // indirect
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect
)
diff --git a/go.sum b/go.sum
index 7b53725..119ba05 100644
--- a/go.sum
+++ b/go.sum
@@ -1,3 +1,5 @@
+github.com/aarzilli/golua v0.0.0-20241229084300-cd31ab23902e h1:ibMKBskN7uMCz9TJgfaIVYVdPyckXm0UjFDRSNV7XB0=
+github.com/aarzilli/golua v0.0.0-20241229084300-cd31ab23902e/go.mod h1:hMjfaJVSqVnxenMlsxrq3Ni+vrm9Hs64tU4M7dhUoO4=
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c=
From 55f4715a5055fd1aa84d640bb0574f9e64c5ba37 Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Sat, 8 Feb 2025 14:16:03 +0100
Subject: [PATCH 02/15] create lua adapter and test
---
adapter/lua.go | 79 +++++++++++++++++++++++++++++++--------------
adapter/lua_test.go | 60 ++++++++++++++++++++++++++++++++++
go.mod | 2 +-
3 files changed, 115 insertions(+), 26 deletions(-)
create mode 100644 adapter/lua_test.go
diff --git a/adapter/lua.go b/adapter/lua.go
index 8c56d4a..a6aceb1 100644
--- a/adapter/lua.go
+++ b/adapter/lua.go
@@ -1,4 +1,4 @@
-package main
+package adapter
import (
"encoding/base64"
@@ -7,39 +7,68 @@ import (
"github.com/aarzilli/golua/lua"
)
-func main() {
- // luaB64 := `bG9jYWwgZnVuY3Rpb24gZF9nZW4oZGF0YSwgY291bnRlcikKCUhlYWRlciA9IHN0cmluZy5jaGFyKDB4MTIsIDB4MzQsIDB4NTYsIDB4NzgpCgktLSBsb2NhbCB0cyA9IG9zLnRpbWUoKQoJcmV0dXJuIEhlYWRlciAuLiBkYXRhCmVuZAoKbG9jYWwgZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJcmV0dXJuIHN0cmluZy5zdWIoZGF0YSwgI0hlYWRlcikKZW5kCg==`
- // only d_gen
- // luaB64 := `ZnVuY3Rpb24gRF9nZW4oZGF0YSkKCS0tIEhlYWRlciA9IHN0cmluZy5jaGFyKDB4MTIsIDB4MzQsIDB4NTYsIDB4NzgpCglsb2NhbCBIZWFkZXIgPSAiXHgxMlx4MzRceDU2XHg3OCIKCS0tIGxvY2FsIHRzID0gb3MudGltZSgpCglyZXR1cm4gSGVhZGVyIC4uIGRhdGEKZW5kCg==`
- luaB64 := `ZnVuY3Rpb24gRF9nZW4oZGF0YSkKCS0tIEhlYWRlciA9IHN0cmluZy5jaGFyKDB4MTIsIDB4MzQsIDB4NTYsIDB4NzgpCglsb2NhbCBIZWFkZXIgPSAiXHgxMlx4MzRceDU2XHg3OCIKCWxvY2FsIHRzID0gb3MudGltZSgpCglyZXR1cm4gSGVhZGVyIC4uIGRhdGEKZW5kCg==`
- sDec, _ := base64.StdEncoding.DecodeString(luaB64)
- fmt.Println(string(sDec))
- luaCode := sDec
- L := lua.NewState()
- L.OpenLibs()
- defer L.Close()
+// TODO: aSec sync is enough?
+type Lua struct {
+ state *lua.State
+}
+
+type LuaParams struct {
+ LuaCode64 string
+}
+
+func NewLua(params LuaParams) (*Lua, error) {
+ luaCode, err := base64.StdEncoding.DecodeString(params.LuaCode64)
+ if err != nil {
+ return nil, err
+ }
+ fmt.Println(string(luaCode))
+
+ state := lua.NewState()
+ state.OpenLibs()
// Load and execute the Lua code
- if err := L.DoString(string(luaCode)); err != nil {
- fmt.Printf("Error loading Lua code: %v\n", err)
- return
+ if err := state.DoString(string(luaCode)); err != nil {
+ return nil, fmt.Errorf("Error loading Lua code: %v\n", err)
}
+ return &Lua{state: state}, nil
+}
+func (l *Lua) Close() {
+ l.state.Close()
+}
+
+func (l *Lua) Generate(data []byte, counter int64) ([]byte, error) {
// Push the function onto the stack
- L.GetGlobal("D_gen")
+ l.state.GetGlobal("D_gen")
// Push the argument
- L.PushString("data")
+ l.state.PushBytes(data)
+ l.state.PushInteger(counter)
- if err := L.Call(1, 1); err != nil {
- fmt.Printf("Error calling Lua function: %v\n", err)
- return
+ if err := l.state.Call(2, 1); err != nil {
+ return nil, fmt.Errorf("Error calling Lua function: %v\n", err)
}
- result := L.ToString(-1)
- L.Pop(1)
+ result := l.state.ToBytes(-1)
+ l.state.Pop(1)
- // Print the result
- // fmt.Printf("Result: %x\n", []byte(result))
- fmt.Printf("Result: %s\n", result)
+ fmt.Printf("Result: %s\n", string(result))
+ return result, nil
+}
+
+func (l *Lua) Parse(data []byte) ([]byte, error) {
+ // Push the function onto the stack
+ l.state.GetGlobal("D_parse")
+
+ // Push the argument
+ l.state.PushBytes(data)
+ if err := l.state.Call(1, 1); err != nil {
+ return nil, fmt.Errorf("Error calling Lua function: %v\n", err)
+ }
+
+ result := l.state.ToBytes(-1)
+ l.state.Pop(1)
+
+ fmt.Printf("Result: %s\n", string(result))
+ return result, nil
}
diff --git a/adapter/lua_test.go b/adapter/lua_test.go
new file mode 100644
index 0000000..1d35651
--- /dev/null
+++ b/adapter/lua_test.go
@@ -0,0 +1,60 @@
+package adapter
+
+import (
+ "testing"
+)
+
+func newLua() *Lua {
+ lua, _ := NewLua(LuaParams{
+ /*
+ function D_gen(data, counter)
+ local header = "header"
+ return counter .. header .. data
+ end
+
+ function D_parse(data)
+ local header = "10header"
+ return string.sub(data, #header+1)
+ end
+ */
+ LuaCode64: "ZnVuY3Rpb24gRF9nZW4oZGF0YSwgY291bnRlcikKCWxvY2FsIGhlYWRlciA9ICJoZWFkZXIiCglyZXR1cm4gY291bnRlciAuLiBoZWFkZXIgLi4gZGF0YQplbmQKCmZ1bmN0aW9uIERfcGFyc2UoZGF0YSkKCWxvY2FsIGhlYWRlciA9ICIxMGhlYWRlciIKCXJldHVybiBzdHJpbmcuc3ViKGRhdGEsICNoZWFkZXIrMSkKZW5kCg==",
+ })
+ return lua
+}
+
+func TestLua_Generate(t *testing.T) {
+ t.Run("", func(t *testing.T) {
+ l := newLua()
+ defer l.Close()
+ got, err := l.Generate([]byte("test"), 10)
+ if err != nil {
+ t.Errorf(
+ "Lua.Generate() error = %v, wantErr %v",
+ err,
+ nil,
+ )
+ return
+ }
+
+ want := "10headertest"
+ if string(got) != want {
+ t.Errorf("Lua.Generate() = %v, want %v", string(got), want)
+ }
+ })
+}
+
+func TestLua_Parse(t *testing.T) {
+ t.Run("", func(t *testing.T) {
+ l := newLua()
+ defer l.Close()
+ got, err := l.Parse([]byte("10headertest"))
+ if err != nil {
+ t.Errorf("Lua.Parse() error = %v, wantErr %v", err, nil)
+ return
+ }
+ want := "test"
+ if string(got) != want {
+ t.Errorf("Lua.Parse() = %v, want %v", got, want)
+ }
+ })
+}
diff --git a/go.mod b/go.mod
index 8511642..435752a 100644
--- a/go.mod
+++ b/go.mod
@@ -3,6 +3,7 @@ module github.com/amnezia-vpn/amneziawg-go
go 1.23
require (
+ github.com/aarzilli/golua v0.0.0-20241229084300-cd31ab23902e
github.com/tevino/abool/v2 v2.1.0
golang.org/x/crypto v0.21.0
golang.org/x/net v0.21.0
@@ -12,7 +13,6 @@ require (
)
require (
- github.com/aarzilli/golua v0.0.0-20241229084300-cd31ab23902e // indirect
github.com/google/btree v1.0.1 // indirect
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect
)
From 3015f3ea20573a0eada7372b45c255415736896f Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Sat, 8 Feb 2025 14:35:20 +0100
Subject: [PATCH 03/15] change function initials
---
adapter/lua.go | 15 ++++-----------
adapter/lua_test.go | 6 +++---
2 files changed, 7 insertions(+), 14 deletions(-)
diff --git a/adapter/lua.go b/adapter/lua.go
index a6aceb1..47ce7b2 100644
--- a/adapter/lua.go
+++ b/adapter/lua.go
@@ -13,11 +13,11 @@ type Lua struct {
}
type LuaParams struct {
- LuaCode64 string
+ Base64LuaCode string
}
func NewLua(params LuaParams) (*Lua, error) {
- luaCode, err := base64.StdEncoding.DecodeString(params.LuaCode64)
+ luaCode, err := base64.StdEncoding.DecodeString(params.Base64LuaCode)
if err != nil {
return nil, err
}
@@ -26,7 +26,6 @@ func NewLua(params LuaParams) (*Lua, error) {
state := lua.NewState()
state.OpenLibs()
- // Load and execute the Lua code
if err := state.DoString(string(luaCode)); err != nil {
return nil, fmt.Errorf("Error loading Lua code: %v\n", err)
}
@@ -38,10 +37,8 @@ func (l *Lua) Close() {
}
func (l *Lua) Generate(data []byte, counter int64) ([]byte, error) {
- // Push the function onto the stack
- l.state.GetGlobal("D_gen")
+ l.state.GetGlobal("d_gen")
- // Push the argument
l.state.PushBytes(data)
l.state.PushInteger(counter)
@@ -52,15 +49,12 @@ func (l *Lua) Generate(data []byte, counter int64) ([]byte, error) {
result := l.state.ToBytes(-1)
l.state.Pop(1)
- fmt.Printf("Result: %s\n", string(result))
return result, nil
}
func (l *Lua) Parse(data []byte) ([]byte, error) {
- // Push the function onto the stack
- l.state.GetGlobal("D_parse")
+ l.state.GetGlobal("d_parse")
- // Push the argument
l.state.PushBytes(data)
if err := l.state.Call(1, 1); err != nil {
return nil, fmt.Errorf("Error calling Lua function: %v\n", err)
@@ -69,6 +63,5 @@ func (l *Lua) Parse(data []byte) ([]byte, error) {
result := l.state.ToBytes(-1)
l.state.Pop(1)
- fmt.Printf("Result: %s\n", string(result))
return result, nil
}
diff --git a/adapter/lua_test.go b/adapter/lua_test.go
index 1d35651..08d5420 100644
--- a/adapter/lua_test.go
+++ b/adapter/lua_test.go
@@ -7,17 +7,17 @@ import (
func newLua() *Lua {
lua, _ := NewLua(LuaParams{
/*
- function D_gen(data, counter)
+ function d_gen(data, counter)
local header = "header"
return counter .. header .. data
end
- function D_parse(data)
+ function d_parse(data)
local header = "10header"
return string.sub(data, #header+1)
end
*/
- LuaCode64: "ZnVuY3Rpb24gRF9nZW4oZGF0YSwgY291bnRlcikKCWxvY2FsIGhlYWRlciA9ICJoZWFkZXIiCglyZXR1cm4gY291bnRlciAuLiBoZWFkZXIgLi4gZGF0YQplbmQKCmZ1bmN0aW9uIERfcGFyc2UoZGF0YSkKCWxvY2FsIGhlYWRlciA9ICIxMGhlYWRlciIKCXJldHVybiBzdHJpbmcuc3ViKGRhdGEsICNoZWFkZXIrMSkKZW5kCg==",
+ Base64LuaCode: "ZnVuY3Rpb24gZF9nZW4oZGF0YSwgY291bnRlcikKCWxvY2FsIGhlYWRlciA9ICJoZWFkZXIiCglyZXR1cm4gY291bnRlciAuLiBoZWFkZXIgLi4gZGF0YQplbmQKCmZ1bmN0aW9uIGRfcGFyc2UoZGF0YSkKCWxvY2FsIGhlYWRlciA9ICIxMGhlYWRlciIKCXJldHVybiBzdHJpbmcuc3ViKGRhdGEsICNoZWFkZXIrMSkKZW5kCg==",
})
return lua
}
From 32470fa04ee4b131ffca0b13d55a6b362e50b782 Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Sat, 8 Feb 2025 19:00:00 +0100
Subject: [PATCH 04/15] add codec generation/parsing
---
device/device.go | 14 ++++++++++----
device/receive.go | 8 +++++++-
device/send.go | 32 +++++++++++++++++++++++++++++++-
3 files changed, 48 insertions(+), 6 deletions(-)
diff --git a/device/device.go b/device/device.go
index 1bedce3..8178667 100644
--- a/device/device.go
+++ b/device/device.go
@@ -11,6 +11,7 @@ import (
"sync/atomic"
"time"
+ "github.com/amnezia-vpn/amneziawg-go/adapter"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/ipc"
"github.com/amnezia-vpn/amneziawg-go/ratelimiter"
@@ -92,11 +93,13 @@ type Device struct {
closed chan struct{}
log *Logger
- isASecOn abool.AtomicBool
- aSecMux sync.RWMutex
- aSecCfg aSecCfgType
-
+ isASecOn abool.AtomicBool
+ aSecMux sync.RWMutex
+ aSecCfg aSecCfgType
junkCreator junkCreator
+
+ luaAdapter *adapter.Lua
+ packetCounter atomic.Int64
}
type aSecCfgType struct {
@@ -428,6 +431,9 @@ func (device *Device) Close() {
device.resetProtocol()
+ if device.luaAdapter != nil {
+ device.luaAdapter.Close()
+ }
device.log.Verbosef("Device closed")
close(device.closed)
}
diff --git a/device/receive.go b/device/receive.go
index 66c1a32..e790048 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -137,8 +137,14 @@ func (device *Device) RoutineReceiveIncoming(
}
// check size of packet
-
packet := bufsArrs[i][:size]
+ if device.luaAdapter != nil {
+ packet, err = device.luaAdapter.Parse(packet)
+ if err != nil {
+ device.log.Verbosef("Couldn't parse message; reason: %v", err)
+ continue
+ }
+ }
var msgType uint32
if device.isAdvancedSecurityOn() {
if assumedMsgType, ok := packetSizeToMsgType[size]; ok {
diff --git a/device/send.go b/device/send.go
index 5c54d4d..c0b9ad4 100644
--- a/device/send.go
+++ b/device/send.go
@@ -175,6 +175,10 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
peer.cookieGenerator.AddMacs(packet)
junkedHeader = append(junkedHeader, packet...)
+ if junkedHeader, err = peer.device.codecPacket(junkedHeader); err != nil {
+ return err
+ }
+
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
@@ -233,6 +237,10 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.cookieGenerator.AddMacs(packet)
junkedHeader = append(junkedHeader, packet...)
+ if junkedHeader, err = peer.device.codecPacket(junkedHeader); err != nil {
+ return err
+ }
+
err = peer.BeginSymmetricSession()
if err != nil {
peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
@@ -277,8 +285,13 @@ func (device *Device) SendHandshakeCookie(
var buf [MessageCookieReplySize]byte
writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, reply)
+ packet := writer.Bytes()
+ if packet, err = device.codecPacket(packet); err != nil {
+ return err
+ }
+
// TODO: allocation could be avoided
- device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
+ device.net.bind.Send([][]byte{packet}, initiatingElem.endpoint)
return nil
}
@@ -534,6 +547,18 @@ func calculatePaddingSize(packetSize, mtu int) int {
return paddedSize - lastUnit
}
+func (device *Device) codecPacket(packet []byte) ([]byte, error) {
+ if device.luaAdapter != nil {
+ var err error
+ packet, err = device.luaAdapter.Generate(packet, device.packetCounter.Add(1))
+ if err != nil {
+ device.log.Errorf("%v - Failed to run codec generate: %v", device, err)
+ return nil, err
+ }
+ }
+ return packet, nil
+}
+
/* Encrypts the elements in the queue
* and marks them for sequential consumption (by releasing the mutex)
*
@@ -578,6 +603,11 @@ func (device *Device) RoutineEncryption(id int) {
elem.packet,
nil,
)
+ // TODO: check
+ var err error
+ if elem.packet, err = device.codecPacket(elem.packet); err != nil {
+ continue
+ }
}
elemsContainer.Unlock()
}
From 553060f804a4429356037143a3dacd71d2e74a19 Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Sun, 9 Feb 2025 09:21:01 +0100
Subject: [PATCH 05/15] Send msg type to code & add default msg types
---
adapter/lua.go | 5 +++--
adapter/lua_test.go | 20 ++++++++++----------
device/device.go | 17 +++++++++--------
device/noise-protocol.go | 15 +++++++++++----
device/send.go | 14 +++++++-------
5 files changed, 40 insertions(+), 31 deletions(-)
diff --git a/adapter/lua.go b/adapter/lua.go
index 47ce7b2..e44ca70 100644
--- a/adapter/lua.go
+++ b/adapter/lua.go
@@ -36,13 +36,14 @@ func (l *Lua) Close() {
l.state.Close()
}
-func (l *Lua) Generate(data []byte, counter int64) ([]byte, error) {
+func (l *Lua) Generate(msgType int64, data []byte, counter int64) ([]byte, error) {
l.state.GetGlobal("d_gen")
+ l.state.PushInteger(msgType)
l.state.PushBytes(data)
l.state.PushInteger(counter)
- if err := l.state.Call(2, 1); err != nil {
+ if err := l.state.Call(3, 1); err != nil {
return nil, fmt.Errorf("Error calling Lua function: %v\n", err)
}
diff --git a/adapter/lua_test.go b/adapter/lua_test.go
index 08d5420..828b555 100644
--- a/adapter/lua_test.go
+++ b/adapter/lua_test.go
@@ -7,17 +7,17 @@ import (
func newLua() *Lua {
lua, _ := NewLua(LuaParams{
/*
- function d_gen(data, counter)
- local header = "header"
- return counter .. header .. data
- end
+ function d_gen(msg_type, data, counter)
+ local header = "header"
+ return counter .. header .. data
+ end
- function d_parse(data)
- local header = "10header"
- return string.sub(data, #header+1)
- end
+ function d_parse(data)
+ local header = "10header"
+ return string.sub(data, #header+1)
+ end
*/
- Base64LuaCode: "ZnVuY3Rpb24gZF9nZW4oZGF0YSwgY291bnRlcikKCWxvY2FsIGhlYWRlciA9ICJoZWFkZXIiCglyZXR1cm4gY291bnRlciAuLiBoZWFkZXIgLi4gZGF0YQplbmQKCmZ1bmN0aW9uIGRfcGFyc2UoZGF0YSkKCWxvY2FsIGhlYWRlciA9ICIxMGhlYWRlciIKCXJldHVybiBzdHJpbmcuc3ViKGRhdGEsICNoZWFkZXIrMSkKZW5kCg==",
+ Base64LuaCode: "CmZ1bmN0aW9uIGRfZ2VuKG1zZ190eXBlLCBkYXRhLCBjb3VudGVyKQoJbG9jYWwgaGVhZGVyID0gImhlYWRlciIKCXJldHVybiBjb3VudGVyIC4uIGhlYWRlciAuLiBkYXRhCmVuZAoKZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJbG9jYWwgaGVhZGVyID0gIjEwaGVhZGVyIgoJcmV0dXJuIHN0cmluZy5zdWIoZGF0YSwgI2hlYWRlcisxKQplbmQK",
})
return lua
}
@@ -26,7 +26,7 @@ func TestLua_Generate(t *testing.T) {
t.Run("", func(t *testing.T) {
l := newLua()
defer l.Close()
- got, err := l.Generate([]byte("test"), 10)
+ got, err := l.Generate(1, []byte("test"), 10)
if err != nil {
t.Errorf(
"Lua.Generate() error = %v, wantErr %v",
diff --git a/device/device.go b/device/device.go
index 8178667..aaddcbc 100644
--- a/device/device.go
+++ b/device/device.go
@@ -433,6 +433,7 @@ func (device *Device) Close() {
if device.luaAdapter != nil {
device.luaAdapter.Close()
+ device.luaAdapter = nil
}
device.log.Verbosef("Device closed")
close(device.closed)
@@ -585,10 +586,10 @@ func (device *Device) isAdvancedSecurityOn() bool {
func (device *Device) resetProtocol() {
// restore default message type values
- MessageInitiationType = 1
- MessageResponseType = 2
- MessageCookieReplyType = 3
- MessageTransportType = 4
+ MessageInitiationType = DefaultMessageInitiationType
+ MessageResponseType = DefaultMessageResponseType
+ MessageCookieReplyType = DefaultMessageCookieReplyType
+ MessageTransportType = DefaultMessageTransportType
}
func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
@@ -721,7 +722,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
MessageInitiationType = device.aSecCfg.initPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default init type")
- MessageInitiationType = 1
+ MessageInitiationType = DefaultMessageInitiationType
}
if tempASecCfg.responsePacketMagicHeader > 4 {
@@ -731,7 +732,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
MessageResponseType = device.aSecCfg.responsePacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default response type")
- MessageResponseType = 2
+ MessageResponseType = DefaultMessageResponseType
}
if tempASecCfg.underloadPacketMagicHeader > 4 {
@@ -741,7 +742,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default underload type")
- MessageCookieReplyType = 3
+ MessageCookieReplyType = DefaultMessageCookieReplyType
}
if tempASecCfg.transportPacketMagicHeader > 4 {
@@ -751,7 +752,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
MessageTransportType = device.aSecCfg.transportPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default transport type")
- MessageTransportType = 4
+ MessageTransportType = DefaultMessageTransportType
}
isSameMap := map[uint32]bool{}
diff --git a/device/noise-protocol.go b/device/noise-protocol.go
index 1289249..d818238 100644
--- a/device/noise-protocol.go
+++ b/device/noise-protocol.go
@@ -52,11 +52,18 @@ const (
WGLabelCookie = "cookie--"
)
+const (
+ DefaultMessageInitiationType uint32 = 1
+ DefaultMessageResponseType uint32 = 2
+ DefaultMessageCookieReplyType uint32 = 3
+ DefaultMessageTransportType uint32 = 4
+)
+
var (
- MessageInitiationType uint32 = 1
- MessageResponseType uint32 = 2
- MessageCookieReplyType uint32 = 3
- MessageTransportType uint32 = 4
+ MessageInitiationType uint32 = DefaultMessageInitiationType
+ MessageResponseType uint32 = DefaultMessageResponseType
+ MessageCookieReplyType uint32 = DefaultMessageCookieReplyType
+ MessageTransportType uint32 = DefaultMessageTransportType
)
const (
diff --git a/device/send.go b/device/send.go
index c0b9ad4..ac05313 100644
--- a/device/send.go
+++ b/device/send.go
@@ -175,7 +175,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
peer.cookieGenerator.AddMacs(packet)
junkedHeader = append(junkedHeader, packet...)
- if junkedHeader, err = peer.device.codecPacket(junkedHeader); err != nil {
+ if junkedHeader, err = peer.device.codecPacket(DefaultMessageInitiationType, junkedHeader); err != nil {
return err
}
@@ -237,7 +237,7 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.cookieGenerator.AddMacs(packet)
junkedHeader = append(junkedHeader, packet...)
- if junkedHeader, err = peer.device.codecPacket(junkedHeader); err != nil {
+ if junkedHeader, err = peer.device.codecPacket(DefaultMessageResponseType, junkedHeader); err != nil {
return err
}
@@ -286,7 +286,7 @@ func (device *Device) SendHandshakeCookie(
writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, reply)
packet := writer.Bytes()
- if packet, err = device.codecPacket(packet); err != nil {
+ if packet, err = device.codecPacket(DefaultMessageCookieReplyType, packet); err != nil {
return err
}
@@ -547,10 +547,10 @@ func calculatePaddingSize(packetSize, mtu int) int {
return paddedSize - lastUnit
}
-func (device *Device) codecPacket(packet []byte) ([]byte, error) {
+func (device *Device) codecPacket(msgType uint32, packet []byte) ([]byte, error) {
if device.luaAdapter != nil {
var err error
- packet, err = device.luaAdapter.Generate(packet, device.packetCounter.Add(1))
+ packet, err = device.luaAdapter.Generate(int64(msgType),packet, device.packetCounter.Add(1))
if err != nil {
device.log.Errorf("%v - Failed to run codec generate: %v", device, err)
return nil, err
@@ -603,9 +603,9 @@ func (device *Device) RoutineEncryption(id int) {
elem.packet,
nil,
)
- // TODO: check
+
var err error
- if elem.packet, err = device.codecPacket(elem.packet); err != nil {
+ if elem.packet, err = device.codecPacket(DefaultMessageTransportType, elem.packet); err != nil {
continue
}
}
From e08105f85b8b1da6ae558d5c14672c34e2338a47 Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Sun, 9 Feb 2025 09:36:41 +0100
Subject: [PATCH 06/15] migrate counter
---
adapter/lua.go | 11 ++++++++---
adapter/lua_test.go | 4 ++--
2 files changed, 10 insertions(+), 5 deletions(-)
diff --git a/adapter/lua.go b/adapter/lua.go
index e44ca70..1a845c4 100644
--- a/adapter/lua.go
+++ b/adapter/lua.go
@@ -3,13 +3,15 @@ package adapter
import (
"encoding/base64"
"fmt"
+ "sync/atomic"
"github.com/aarzilli/golua/lua"
)
// TODO: aSec sync is enough?
type Lua struct {
- state *lua.State
+ state *lua.State
+ packetCounter atomic.Int64
}
type LuaParams struct {
@@ -36,12 +38,15 @@ func (l *Lua) Close() {
l.state.Close()
}
-func (l *Lua) Generate(msgType int64, data []byte, counter int64) ([]byte, error) {
+func (l *Lua) Generate(
+ msgType int64,
+ data []byte,
+) ([]byte, error) {
l.state.GetGlobal("d_gen")
l.state.PushInteger(msgType)
l.state.PushBytes(data)
- l.state.PushInteger(counter)
+ l.state.PushInteger(l.packetCounter.Add(1))
if err := l.state.Call(3, 1); err != nil {
return nil, fmt.Errorf("Error calling Lua function: %v\n", err)
diff --git a/adapter/lua_test.go b/adapter/lua_test.go
index 828b555..c7027d0 100644
--- a/adapter/lua_test.go
+++ b/adapter/lua_test.go
@@ -26,7 +26,7 @@ func TestLua_Generate(t *testing.T) {
t.Run("", func(t *testing.T) {
l := newLua()
defer l.Close()
- got, err := l.Generate(1, []byte("test"), 10)
+ got, err := l.Generate(1, []byte("test"))
if err != nil {
t.Errorf(
"Lua.Generate() error = %v, wantErr %v",
@@ -36,7 +36,7 @@ func TestLua_Generate(t *testing.T) {
return
}
- want := "10headertest"
+ want := "1headertest"
if string(got) != want {
t.Errorf("Lua.Generate() = %v, want %v", string(got), want)
}
From 6b0bbcc75cb88c1cf1370ca40b8907e378b205e8 Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Sun, 9 Feb 2025 09:36:53 +0100
Subject: [PATCH 07/15] migrate awg specific code to struct
---
device/device.go | 65 ++++++++++++++++++++-----------------
device/junk_creator.go | 10 +++---
device/junk_creator_test.go | 8 ++---
device/noise-protocol.go | 20 ++++++------
device/receive.go | 12 +++----
device/send.go | 34 +++++++++----------
device/uapi.go | 36 ++++++++++----------
7 files changed, 95 insertions(+), 90 deletions(-)
diff --git a/device/device.go b/device/device.go
index aaddcbc..f375663 100644
--- a/device/device.go
+++ b/device/device.go
@@ -92,14 +92,17 @@ type Device struct {
ipcMutex sync.RWMutex
closed chan struct{}
log *Logger
+
+ awg awgType
+}
+type awgType struct {
isASecOn abool.AtomicBool
aSecMux sync.RWMutex
aSecCfg aSecCfgType
junkCreator junkCreator
luaAdapter *adapter.Lua
- packetCounter atomic.Int64
}
type aSecCfgType struct {
@@ -431,9 +434,9 @@ func (device *Device) Close() {
device.resetProtocol()
- if device.luaAdapter != nil {
- device.luaAdapter.Close()
- device.luaAdapter = nil
+ if device.awg.luaAdapter != nil {
+ device.awg.luaAdapter.Close()
+ device.awg.luaAdapter = nil
}
device.log.Verbosef("Device closed")
close(device.closed)
@@ -581,7 +584,7 @@ func (device *Device) BindClose() error {
return err
}
func (device *Device) isAdvancedSecurityOn() bool {
- return device.isASecOn.IsSet()
+ return device.awg.isASecOn.IsSet()
}
func (device *Device) resetProtocol() {
@@ -594,36 +597,36 @@ func (device *Device) resetProtocol() {
func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
if !tempASecCfg.isSet {
- return err
+ return nil
}
isASecOn := false
- device.aSecMux.Lock()
+ device.awg.aSecMux.Lock()
if tempASecCfg.junkPacketCount < 0 {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketCount should be non negative",
)
}
- device.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount
+ device.awg.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount
if tempASecCfg.junkPacketCount != 0 {
isASecOn = true
}
- device.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize
+ device.awg.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize
if tempASecCfg.junkPacketMinSize != 0 {
isASecOn = true
}
- if device.aSecCfg.junkPacketCount > 0 &&
+ if device.awg.aSecCfg.junkPacketCount > 0 &&
tempASecCfg.junkPacketMaxSize == tempASecCfg.junkPacketMinSize {
tempASecCfg.junkPacketMaxSize++ // to make rand gen work
}
if tempASecCfg.junkPacketMaxSize >= MaxSegmentSize {
- device.aSecCfg.junkPacketMinSize = 0
- device.aSecCfg.junkPacketMaxSize = 1
+ device.awg.aSecCfg.junkPacketMinSize = 0
+ device.awg.aSecCfg.junkPacketMaxSize = 1
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
@@ -658,7 +661,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
)
}
} else {
- device.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize
+ device.awg.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize
}
if tempASecCfg.junkPacketMaxSize != 0 {
@@ -683,7 +686,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
)
}
} else {
- device.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize
+ device.awg.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize
}
if tempASecCfg.initPacketJunkSize != 0 {
@@ -708,7 +711,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
)
}
} else {
- device.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize
+ device.awg.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize
}
if tempASecCfg.responsePacketJunkSize != 0 {
@@ -718,8 +721,8 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
if tempASecCfg.initPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating init_packet_magic_header")
- device.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader
- MessageInitiationType = device.aSecCfg.initPacketMagicHeader
+ device.awg.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader
+ MessageInitiationType = device.awg.aSecCfg.initPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default init type")
MessageInitiationType = DefaultMessageInitiationType
@@ -728,8 +731,8 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
if tempASecCfg.responsePacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
- device.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader
- MessageResponseType = device.aSecCfg.responsePacketMagicHeader
+ device.awg.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader
+ MessageResponseType = device.awg.aSecCfg.responsePacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default response type")
MessageResponseType = DefaultMessageResponseType
@@ -738,8 +741,8 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
if tempASecCfg.underloadPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
- device.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader
- MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader
+ device.awg.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader
+ MessageCookieReplyType = device.awg.aSecCfg.underloadPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default underload type")
MessageCookieReplyType = DefaultMessageCookieReplyType
@@ -748,8 +751,8 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
if tempASecCfg.transportPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
- device.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader
- MessageTransportType = device.aSecCfg.transportPacketMagicHeader
+ device.awg.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader
+ MessageTransportType = device.awg.aSecCfg.transportPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default transport type")
MessageTransportType = DefaultMessageTransportType
@@ -785,8 +788,8 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
}
}
- newInitSize := MessageInitiationSize + device.aSecCfg.initPacketJunkSize
- newResponseSize := MessageResponseSize + device.aSecCfg.responsePacketJunkSize
+ newInitSize := MessageInitiationSize + device.awg.aSecCfg.initPacketJunkSize
+ newResponseSize := MessageResponseSize + device.awg.aSecCfg.responsePacketJunkSize
if newInitSize == newResponseSize {
if err != nil {
@@ -814,16 +817,18 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
}
msgTypeToJunkSize = map[uint32]int{
- MessageInitiationType: device.aSecCfg.initPacketJunkSize,
- MessageResponseType: device.aSecCfg.responsePacketJunkSize,
+ MessageInitiationType: device.awg.aSecCfg.initPacketJunkSize,
+ MessageResponseType: device.awg.aSecCfg.responsePacketJunkSize,
MessageCookieReplyType: 0,
MessageTransportType: 0,
}
}
- device.isASecOn.SetTo(isASecOn)
- device.junkCreator, err = NewJunkCreator(device)
- device.aSecMux.Unlock()
+ device.awg.isASecOn.SetTo(isASecOn)
+ if device.awg.isASecOn.IsSet() {
+ device.awg.junkCreator, err = NewJunkCreator(device)
+ }
+ device.awg.aSecMux.Unlock()
return err
}
diff --git a/device/junk_creator.go b/device/junk_creator.go
index 85a5bbc..2bfb197 100644
--- a/device/junk_creator.go
+++ b/device/junk_creator.go
@@ -23,12 +23,12 @@ func NewJunkCreator(d *Device) (junkCreator, error) {
// Should be called with aSecMux RLocked
func (jc *junkCreator) createJunkPackets(peer *Peer) ([][]byte, error) {
- if jc.device.aSecCfg.junkPacketCount == 0 {
+ if jc.device.awg.aSecCfg.junkPacketCount == 0 {
return nil, nil
}
- junks := make([][]byte, 0, jc.device.aSecCfg.junkPacketCount)
- for i := 0; i < jc.device.aSecCfg.junkPacketCount; i++ {
+ junks := make([][]byte, 0, jc.device.awg.aSecCfg.junkPacketCount)
+ for i := 0; i < jc.device.awg.aSecCfg.junkPacketCount; i++ {
packetSize := jc.randomPacketSize()
junk, err := jc.randomJunkWithSize(packetSize)
if err != nil {
@@ -48,9 +48,9 @@ func (jc *junkCreator) createJunkPackets(peer *Peer) ([][]byte, error) {
func (jc *junkCreator) randomPacketSize() int {
return int(
jc.cha8Rand.Uint64()%uint64(
- jc.device.aSecCfg.junkPacketMaxSize-jc.device.aSecCfg.junkPacketMinSize,
+ jc.device.awg.aSecCfg.junkPacketMaxSize-jc.device.awg.aSecCfg.junkPacketMinSize,
),
- ) + jc.device.aSecCfg.junkPacketMinSize
+ ) + jc.device.awg.aSecCfg.junkPacketMinSize
}
// Should be called with aSecMux RLocked
diff --git a/device/junk_creator_test.go b/device/junk_creator_test.go
index 6f63360..96fa6d3 100644
--- a/device/junk_creator_test.go
+++ b/device/junk_creator_test.go
@@ -91,13 +91,13 @@ func Test_junkCreator_randomPacketSize(t *testing.T) {
}
for range [30]struct{}{} {
t.Run("", func(t *testing.T) {
- if got := jc.randomPacketSize(); jc.device.aSecCfg.junkPacketMinSize > got ||
- got > jc.device.aSecCfg.junkPacketMaxSize {
+ if got := jc.randomPacketSize(); jc.device.awg.aSecCfg.junkPacketMinSize > got ||
+ got > jc.device.awg.aSecCfg.junkPacketMaxSize {
t.Errorf(
"junkCreator.randomPacketSize() = %v, not between range [%v,%v]",
got,
- jc.device.aSecCfg.junkPacketMinSize,
- jc.device.aSecCfg.junkPacketMaxSize,
+ jc.device.awg.aSecCfg.junkPacketMinSize,
+ jc.device.awg.aSecCfg.junkPacketMaxSize,
)
}
})
diff --git a/device/noise-protocol.go b/device/noise-protocol.go
index d818238..2dc0d93 100644
--- a/device/noise-protocol.go
+++ b/device/noise-protocol.go
@@ -204,12 +204,12 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(handshake.remoteStatic[:])
- device.aSecMux.RLock()
+ device.awg.aSecMux.RLock()
msg := MessageInitiation{
Type: MessageInitiationType,
Ephemeral: handshake.localEphemeral.publicKey(),
}
- device.aSecMux.RUnlock()
+ device.awg.aSecMux.RUnlock()
handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:])
@@ -263,12 +263,12 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
chainKey [blake2s.Size]byte
)
- device.aSecMux.RLock()
+ device.awg.aSecMux.RLock()
if msg.Type != MessageInitiationType {
- device.aSecMux.RUnlock()
+ device.awg.aSecMux.RUnlock()
return nil
}
- device.aSecMux.RUnlock()
+ device.awg.aSecMux.RUnlock()
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
@@ -383,9 +383,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
}
var msg MessageResponse
- device.aSecMux.RLock()
+ device.awg.aSecMux.RLock()
msg.Type = MessageResponseType
- device.aSecMux.RUnlock()
+ device.awg.aSecMux.RUnlock()
msg.Sender = handshake.localIndex
msg.Receiver = handshake.remoteIndex
@@ -435,12 +435,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
}
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
- device.aSecMux.RLock()
+ device.awg.aSecMux.RLock()
if msg.Type != MessageResponseType {
- device.aSecMux.RUnlock()
+ device.awg.aSecMux.RUnlock()
return nil
}
- device.aSecMux.RUnlock()
+ device.awg.aSecMux.RUnlock()
// lookup handshake by receiver
diff --git a/device/receive.go b/device/receive.go
index e790048..70801ef 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -129,7 +129,7 @@ func (device *Device) RoutineReceiveIncoming(
}
deathSpiral = 0
- device.aSecMux.RLock()
+ device.awg.aSecMux.RLock()
// handle each packet in the batch
for i, size := range sizes[:count] {
if size < MinMessageSize {
@@ -138,8 +138,8 @@ func (device *Device) RoutineReceiveIncoming(
// check size of packet
packet := bufsArrs[i][:size]
- if device.luaAdapter != nil {
- packet, err = device.luaAdapter.Parse(packet)
+ if device.awg.luaAdapter != nil {
+ packet, err = device.awg.luaAdapter.Parse(packet)
if err != nil {
device.log.Verbosef("Couldn't parse message; reason: %v", err)
continue
@@ -251,7 +251,7 @@ func (device *Device) RoutineReceiveIncoming(
default:
}
}
- device.aSecMux.RUnlock()
+ device.awg.aSecMux.RUnlock()
for peer, elemsContainer := range elemsByPeer {
if peer.isRunning.Load() {
peer.queue.inbound.c <- elemsContainer
@@ -310,7 +310,7 @@ func (device *Device) RoutineHandshake(id int) {
for elem := range device.queue.handshake.c {
- device.aSecMux.RLock()
+ device.awg.aSecMux.RLock()
// handle cookie fields and ratelimiting
@@ -462,7 +462,7 @@ func (device *Device) RoutineHandshake(id int) {
peer.SendKeepalive()
}
skip:
- device.aSecMux.RUnlock()
+ device.awg.aSecMux.RUnlock()
device.PutMessageBuffer(elem.buffer)
}
}
diff --git a/device/send.go b/device/send.go
index ac05313..297f19e 100644
--- a/device/send.go
+++ b/device/send.go
@@ -131,9 +131,9 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
// so only packet processed for cookie generation
var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() {
- peer.device.aSecMux.RLock()
- junks, err := peer.device.junkCreator.createJunkPackets(peer)
- peer.device.aSecMux.RUnlock()
+ peer.device.awg.aSecMux.RLock()
+ junks, err := peer.device.awg.junkCreator.createJunkPackets(peer)
+ peer.device.awg.aSecMux.RUnlock()
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
@@ -153,19 +153,19 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
}
}
- peer.device.aSecMux.RLock()
- if peer.device.aSecCfg.initPacketJunkSize != 0 {
- buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize)
+ peer.device.awg.aSecMux.RLock()
+ if peer.device.awg.aSecCfg.initPacketJunkSize != 0 {
+ buf := make([]byte, 0, peer.device.awg.aSecCfg.initPacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
- err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize)
+ err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.initPacketJunkSize)
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
- peer.device.aSecMux.RUnlock()
+ peer.device.awg.aSecMux.RUnlock()
return err
}
junkedHeader = writer.Bytes()
}
- peer.device.aSecMux.RUnlock()
+ peer.device.awg.aSecMux.RUnlock()
}
var buf [MessageInitiationSize]byte
@@ -215,19 +215,19 @@ func (peer *Peer) SendHandshakeResponse() error {
}
var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() {
- peer.device.aSecMux.RLock()
- if peer.device.aSecCfg.responsePacketJunkSize != 0 {
- buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize)
+ peer.device.awg.aSecMux.RLock()
+ if peer.device.awg.aSecCfg.responsePacketJunkSize != 0 {
+ buf := make([]byte, 0, peer.device.awg.aSecCfg.responsePacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
- err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize)
+ err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.responsePacketJunkSize)
if err != nil {
- peer.device.aSecMux.RUnlock()
+ peer.device.awg.aSecMux.RUnlock()
peer.device.log.Errorf("%v - %v", peer, err)
return err
}
junkedHeader = writer.Bytes()
}
- peer.device.aSecMux.RUnlock()
+ peer.device.awg.aSecMux.RUnlock()
}
var buf [MessageResponseSize]byte
writer := bytes.NewBuffer(buf[:0])
@@ -548,9 +548,9 @@ func calculatePaddingSize(packetSize, mtu int) int {
}
func (device *Device) codecPacket(msgType uint32, packet []byte) ([]byte, error) {
- if device.luaAdapter != nil {
+ if device.awg.luaAdapter != nil {
var err error
- packet, err = device.luaAdapter.Generate(int64(msgType),packet, device.packetCounter.Add(1))
+ packet, err = device.awg.luaAdapter.Generate(int64(msgType),packet)
if err != nil {
device.log.Errorf("%v - Failed to run codec generate: %v", device, err)
return nil, err
diff --git a/device/uapi.go b/device/uapi.go
index 777bdda..1bce39b 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -98,32 +98,32 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
}
if device.isAdvancedSecurityOn() {
- if device.aSecCfg.junkPacketCount != 0 {
- sendf("jc=%d", device.aSecCfg.junkPacketCount)
+ if device.awg.aSecCfg.junkPacketCount != 0 {
+ sendf("jc=%d", device.awg.aSecCfg.junkPacketCount)
}
- if device.aSecCfg.junkPacketMinSize != 0 {
- sendf("jmin=%d", device.aSecCfg.junkPacketMinSize)
+ if device.awg.aSecCfg.junkPacketMinSize != 0 {
+ sendf("jmin=%d", device.awg.aSecCfg.junkPacketMinSize)
}
- if device.aSecCfg.junkPacketMaxSize != 0 {
- sendf("jmax=%d", device.aSecCfg.junkPacketMaxSize)
+ if device.awg.aSecCfg.junkPacketMaxSize != 0 {
+ sendf("jmax=%d", device.awg.aSecCfg.junkPacketMaxSize)
}
- if device.aSecCfg.initPacketJunkSize != 0 {
- sendf("s1=%d", device.aSecCfg.initPacketJunkSize)
+ if device.awg.aSecCfg.initPacketJunkSize != 0 {
+ sendf("s1=%d", device.awg.aSecCfg.initPacketJunkSize)
}
- if device.aSecCfg.responsePacketJunkSize != 0 {
- sendf("s2=%d", device.aSecCfg.responsePacketJunkSize)
+ if device.awg.aSecCfg.responsePacketJunkSize != 0 {
+ sendf("s2=%d", device.awg.aSecCfg.responsePacketJunkSize)
}
- if device.aSecCfg.initPacketMagicHeader != 0 {
- sendf("h1=%d", device.aSecCfg.initPacketMagicHeader)
+ if device.awg.aSecCfg.initPacketMagicHeader != 0 {
+ sendf("h1=%d", device.awg.aSecCfg.initPacketMagicHeader)
}
- if device.aSecCfg.responsePacketMagicHeader != 0 {
- sendf("h2=%d", device.aSecCfg.responsePacketMagicHeader)
+ if device.awg.aSecCfg.responsePacketMagicHeader != 0 {
+ sendf("h2=%d", device.awg.aSecCfg.responsePacketMagicHeader)
}
- if device.aSecCfg.underloadPacketMagicHeader != 0 {
- sendf("h3=%d", device.aSecCfg.underloadPacketMagicHeader)
+ if device.awg.aSecCfg.underloadPacketMagicHeader != 0 {
+ sendf("h3=%d", device.awg.aSecCfg.underloadPacketMagicHeader)
}
- if device.aSecCfg.transportPacketMagicHeader != 0 {
- sendf("h4=%d", device.aSecCfg.transportPacketMagicHeader)
+ if device.awg.aSecCfg.transportPacketMagicHeader != 0 {
+ sendf("h4=%d", device.awg.aSecCfg.transportPacketMagicHeader)
}
}
From f4bc11733deb06a12cec3f7f345d4da281847e37 Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Sun, 9 Feb 2025 09:50:24 +0100
Subject: [PATCH 08/15] add lua codec parsing
---
adapter/lua.go | 7 ++++-
device/device.go | 75 ++++++++++++++++++++++++------------------------
device/uapi.go | 59 ++++++++++++++++++++++---------------
3 files changed, 80 insertions(+), 61 deletions(-)
diff --git a/adapter/lua.go b/adapter/lua.go
index 1a845c4..371878d 100644
--- a/adapter/lua.go
+++ b/adapter/lua.go
@@ -12,6 +12,7 @@ import (
type Lua struct {
state *lua.State
packetCounter atomic.Int64
+ base64LuaCode string
}
type LuaParams struct {
@@ -31,7 +32,7 @@ func NewLua(params LuaParams) (*Lua, error) {
if err := state.DoString(string(luaCode)); err != nil {
return nil, fmt.Errorf("Error loading Lua code: %v\n", err)
}
- return &Lua{state: state}, nil
+ return &Lua{state: state, base64LuaCode: params.Base64LuaCode}, nil
}
func (l *Lua) Close() {
@@ -71,3 +72,7 @@ func (l *Lua) Parse(data []byte) ([]byte, error) {
return result, nil
}
+
+func (l *Lua) Base64LuaCode() string {
+ return l.base64LuaCode
+}
diff --git a/device/device.go b/device/device.go
index f375663..b42ddc0 100644
--- a/device/device.go
+++ b/device/device.go
@@ -595,43 +595,43 @@ func (device *Device) resetProtocol() {
MessageTransportType = DefaultMessageTransportType
}
-func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
- if !tempASecCfg.isSet {
+func (device *Device) handlePostConfig(tempAwgType *awgType) (err error) {
+ if !tempAwgType.aSecCfg.isSet {
return nil
}
isASecOn := false
device.awg.aSecMux.Lock()
- if tempASecCfg.junkPacketCount < 0 {
+ if tempAwgType.aSecCfg.junkPacketCount < 0 {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketCount should be non negative",
)
}
- device.awg.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount
- if tempASecCfg.junkPacketCount != 0 {
+ device.awg.aSecCfg.junkPacketCount = tempAwgType.aSecCfg.junkPacketCount
+ if tempAwgType.aSecCfg.junkPacketCount != 0 {
isASecOn = true
}
- device.awg.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize
- if tempASecCfg.junkPacketMinSize != 0 {
+ device.awg.aSecCfg.junkPacketMinSize = tempAwgType.aSecCfg.junkPacketMinSize
+ if tempAwgType.aSecCfg.junkPacketMinSize != 0 {
isASecOn = true
}
if device.awg.aSecCfg.junkPacketCount > 0 &&
- tempASecCfg.junkPacketMaxSize == tempASecCfg.junkPacketMinSize {
+ tempAwgType.aSecCfg.junkPacketMaxSize == tempAwgType.aSecCfg.junkPacketMinSize {
- tempASecCfg.junkPacketMaxSize++ // to make rand gen work
+ tempAwgType.aSecCfg.junkPacketMaxSize++ // to make rand gen work
}
- if tempASecCfg.junkPacketMaxSize >= MaxSegmentSize {
+ if tempAwgType.aSecCfg.junkPacketMaxSize >= MaxSegmentSize {
device.awg.aSecCfg.junkPacketMinSize = 0
device.awg.aSecCfg.junkPacketMaxSize = 1
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d; %w",
- tempASecCfg.junkPacketMaxSize,
+ tempAwgType.aSecCfg.junkPacketMaxSize,
MaxSegmentSize,
err,
)
@@ -639,41 +639,41 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
- tempASecCfg.junkPacketMaxSize,
+ tempAwgType.aSecCfg.junkPacketMaxSize,
MaxSegmentSize,
)
}
- } else if tempASecCfg.junkPacketMaxSize < tempASecCfg.junkPacketMinSize {
+ } else if tempAwgType.aSecCfg.junkPacketMaxSize < tempAwgType.aSecCfg.junkPacketMinSize {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"maxSize: %d; should be greater than minSize: %d; %w",
- tempASecCfg.junkPacketMaxSize,
- tempASecCfg.junkPacketMinSize,
+ tempAwgType.aSecCfg.junkPacketMaxSize,
+ tempAwgType.aSecCfg.junkPacketMinSize,
err,
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"maxSize: %d; should be greater than minSize: %d",
- tempASecCfg.junkPacketMaxSize,
- tempASecCfg.junkPacketMinSize,
+ tempAwgType.aSecCfg.junkPacketMaxSize,
+ tempAwgType.aSecCfg.junkPacketMinSize,
)
}
} else {
- device.awg.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize
+ device.awg.aSecCfg.junkPacketMaxSize = tempAwgType.aSecCfg.junkPacketMaxSize
}
- if tempASecCfg.junkPacketMaxSize != 0 {
+ if tempAwgType.aSecCfg.junkPacketMaxSize != 0 {
isASecOn = true
}
- if MessageInitiationSize+tempASecCfg.initPacketJunkSize >= MaxSegmentSize {
+ if MessageInitiationSize+tempAwgType.aSecCfg.initPacketJunkSize >= MaxSegmentSize {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
- tempASecCfg.initPacketJunkSize,
+ tempAwgType.aSecCfg.initPacketJunkSize,
MaxSegmentSize,
err,
)
@@ -681,24 +681,24 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
- tempASecCfg.initPacketJunkSize,
+ tempAwgType.aSecCfg.initPacketJunkSize,
MaxSegmentSize,
)
}
} else {
- device.awg.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize
+ device.awg.aSecCfg.initPacketJunkSize = tempAwgType.aSecCfg.initPacketJunkSize
}
- if tempASecCfg.initPacketJunkSize != 0 {
+ if tempAwgType.aSecCfg.initPacketJunkSize != 0 {
isASecOn = true
}
- if MessageResponseSize+tempASecCfg.responsePacketJunkSize >= MaxSegmentSize {
+ if MessageResponseSize+tempAwgType.aSecCfg.responsePacketJunkSize >= MaxSegmentSize {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
- tempASecCfg.responsePacketJunkSize,
+ tempAwgType.aSecCfg.responsePacketJunkSize,
MaxSegmentSize,
err,
)
@@ -706,52 +706,52 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
- tempASecCfg.responsePacketJunkSize,
+ tempAwgType.aSecCfg.responsePacketJunkSize,
MaxSegmentSize,
)
}
} else {
- device.awg.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize
+ device.awg.aSecCfg.responsePacketJunkSize = tempAwgType.aSecCfg.responsePacketJunkSize
}
- if tempASecCfg.responsePacketJunkSize != 0 {
+ if tempAwgType.aSecCfg.responsePacketJunkSize != 0 {
isASecOn = true
}
- if tempASecCfg.initPacketMagicHeader > 4 {
+ if tempAwgType.aSecCfg.initPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating init_packet_magic_header")
- device.awg.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader
+ device.awg.aSecCfg.initPacketMagicHeader = tempAwgType.aSecCfg.initPacketMagicHeader
MessageInitiationType = device.awg.aSecCfg.initPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default init type")
MessageInitiationType = DefaultMessageInitiationType
}
- if tempASecCfg.responsePacketMagicHeader > 4 {
+ if tempAwgType.aSecCfg.responsePacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
- device.awg.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader
+ device.awg.aSecCfg.responsePacketMagicHeader = tempAwgType.aSecCfg.responsePacketMagicHeader
MessageResponseType = device.awg.aSecCfg.responsePacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default response type")
MessageResponseType = DefaultMessageResponseType
}
- if tempASecCfg.underloadPacketMagicHeader > 4 {
+ if tempAwgType.aSecCfg.underloadPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
- device.awg.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader
+ device.awg.aSecCfg.underloadPacketMagicHeader = tempAwgType.aSecCfg.underloadPacketMagicHeader
MessageCookieReplyType = device.awg.aSecCfg.underloadPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default underload type")
MessageCookieReplyType = DefaultMessageCookieReplyType
}
- if tempASecCfg.transportPacketMagicHeader > 4 {
+ if tempAwgType.aSecCfg.transportPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
- device.awg.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader
+ device.awg.aSecCfg.transportPacketMagicHeader = tempAwgType.aSecCfg.transportPacketMagicHeader
MessageTransportType = device.awg.aSecCfg.transportPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default transport type")
@@ -828,6 +828,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
if device.awg.isASecOn.IsSet() {
device.awg.junkCreator, err = NewJunkCreator(device)
}
+ device.awg.luaAdapter = tempAwgType.luaAdapter
device.awg.aSecMux.Unlock()
return err
diff --git a/device/uapi.go b/device/uapi.go
index 1bce39b..a2bb21b 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -18,6 +18,7 @@ import (
"sync"
"time"
+ "github.com/amnezia-vpn/amneziawg-go/adapter"
"github.com/amnezia-vpn/amneziawg-go/ipc"
)
@@ -97,6 +98,9 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
sendf("fwmark=%d", device.net.fwmark)
}
+ if device.awg.luaAdapter != nil {
+ sendf("lua_codec=%s", device.awg.luaAdapter.Base64LuaCode())
+ }
if device.isAdvancedSecurityOn() {
if device.awg.aSecCfg.junkPacketCount != 0 {
sendf("jc=%d", device.awg.aSecCfg.junkPacketCount)
@@ -180,13 +184,13 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
peer := new(ipcSetPeer)
deviceConfig := true
- tempASecCfg := aSecCfgType{}
+ tempAwgTpe := awgType{}
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
if line == "" {
// Blank line means terminate operation.
- err := device.handlePostConfig(&tempASecCfg)
+ err := device.handlePostConfig(&tempAwgTpe)
if err != nil {
return err
}
@@ -217,7 +221,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
var err error
if deviceConfig {
- err = device.handleDeviceLine(key, value, &tempASecCfg)
+ err = device.handleDeviceLine(key, value, &tempAwgTpe)
} else {
err = device.handlePeerLine(peer, key, value)
}
@@ -225,7 +229,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
return err
}
}
- err = device.handlePostConfig(&tempASecCfg)
+ err = device.handlePostConfig(&tempAwgTpe)
if err != nil {
return err
}
@@ -237,7 +241,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
return nil
}
-func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgType) error {
+func (device *Device) handleDeviceLine(key, value string, tempAwgTpe *awgType) error {
switch key {
case "private_key":
var sk NoisePrivateKey
@@ -283,14 +287,23 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
device.log.Verbosef("UAPI: Removing all peers")
device.RemoveAllPeers()
+ case "lua_codec":
+ device.log.Verbosef("UAPI: Updating lua_codec")
+ var err error
+ tempAwgTpe.luaAdapter, err = adapter.NewLua(adapter.LuaParams{
+ Base64LuaCode: value,
+ })
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "invalid lua_codec: %w", err)
+ }
case "jc":
junkPacketCount, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_count %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_count")
- tempASecCfg.junkPacketCount = junkPacketCount
- tempASecCfg.isSet = true
+ tempAwgTpe.aSecCfg.junkPacketCount = junkPacketCount
+ tempAwgTpe.aSecCfg.isSet = true
case "jmin":
junkPacketMinSize, err := strconv.Atoi(value)
@@ -298,8 +311,8 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_min_size %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_min_size")
- tempASecCfg.junkPacketMinSize = junkPacketMinSize
- tempASecCfg.isSet = true
+ tempAwgTpe.aSecCfg.junkPacketMinSize = junkPacketMinSize
+ tempAwgTpe.aSecCfg.isSet = true
case "jmax":
junkPacketMaxSize, err := strconv.Atoi(value)
@@ -307,8 +320,8 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_max_size %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_max_size")
- tempASecCfg.junkPacketMaxSize = junkPacketMaxSize
- tempASecCfg.isSet = true
+ tempAwgTpe.aSecCfg.junkPacketMaxSize = junkPacketMaxSize
+ tempAwgTpe.aSecCfg.isSet = true
case "s1":
initPacketJunkSize, err := strconv.Atoi(value)
@@ -316,8 +329,8 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating init_packet_junk_size")
- tempASecCfg.initPacketJunkSize = initPacketJunkSize
- tempASecCfg.isSet = true
+ tempAwgTpe.aSecCfg.initPacketJunkSize = initPacketJunkSize
+ tempAwgTpe.aSecCfg.isSet = true
case "s2":
responsePacketJunkSize, err := strconv.Atoi(value)
@@ -325,40 +338,40 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating response_packet_junk_size")
- tempASecCfg.responsePacketJunkSize = responsePacketJunkSize
- tempASecCfg.isSet = true
+ tempAwgTpe.aSecCfg.responsePacketJunkSize = responsePacketJunkSize
+ tempAwgTpe.aSecCfg.isSet = true
case "h1":
initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_magic_header %w", err)
}
- tempASecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader)
- tempASecCfg.isSet = true
+ tempAwgTpe.aSecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader)
+ tempAwgTpe.aSecCfg.isSet = true
case "h2":
responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_magic_header %w", err)
}
- tempASecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader)
- tempASecCfg.isSet = true
+ tempAwgTpe.aSecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader)
+ tempAwgTpe.aSecCfg.isSet = true
case "h3":
underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse underload_packet_magic_header %w", err)
}
- tempASecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
- tempASecCfg.isSet = true
+ tempAwgTpe.aSecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
+ tempAwgTpe.aSecCfg.isSet = true
case "h4":
transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse transport_packet_magic_header %w", err)
}
- tempASecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader)
- tempASecCfg.isSet = true
+ tempAwgTpe.aSecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader)
+ tempAwgTpe.aSecCfg.isSet = true
default:
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
From 1e532c1e71f36e9cc1ce0fa7df2f8447a1d08f02 Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Sun, 9 Feb 2025 18:26:57 +0100
Subject: [PATCH 09/15] awg-2 working with identity generator
---
adapter/lua.go | 1 +
adapter/lua_test.go | 6 +++---
device/device.go | 14 +++++++++++++-
device/device_test.go | 9 ++++++---
device/receive.go | 29 +++++++++++++++++++++++++++--
device/send.go | 17 ++++-------------
6 files changed, 54 insertions(+), 22 deletions(-)
diff --git a/adapter/lua.go b/adapter/lua.go
index 371878d..838ff84 100644
--- a/adapter/lua.go
+++ b/adapter/lua.go
@@ -69,6 +69,7 @@ func (l *Lua) Parse(data []byte) ([]byte, error) {
result := l.state.ToBytes(-1)
l.state.Pop(1)
+ // copy(data, result)
return result, nil
}
diff --git a/adapter/lua_test.go b/adapter/lua_test.go
index c7027d0..1a3d039 100644
--- a/adapter/lua_test.go
+++ b/adapter/lua_test.go
@@ -13,11 +13,11 @@ func newLua() *Lua {
end
function d_parse(data)
- local header = "10header"
+ local header = "1header"
return string.sub(data, #header+1)
end
*/
- Base64LuaCode: "CmZ1bmN0aW9uIGRfZ2VuKG1zZ190eXBlLCBkYXRhLCBjb3VudGVyKQoJbG9jYWwgaGVhZGVyID0gImhlYWRlciIKCXJldHVybiBjb3VudGVyIC4uIGhlYWRlciAuLiBkYXRhCmVuZAoKZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJbG9jYWwgaGVhZGVyID0gIjEwaGVhZGVyIgoJcmV0dXJuIHN0cmluZy5zdWIoZGF0YSwgI2hlYWRlcisxKQplbmQK",
+ Base64LuaCode: "CmZ1bmN0aW9uIGRfZ2VuKG1zZ190eXBlLCBkYXRhLCBjb3VudGVyKQoJbG9jYWwgaGVhZGVyID0gImhlYWRlciIKCXJldHVybiBjb3VudGVyIC4uIGhlYWRlciAuLiBkYXRhCmVuZAoKZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJbG9jYWwgaGVhZGVyID0gIjFoZWFkZXIiCglyZXR1cm4gc3RyaW5nLnN1YihkYXRhLCAjaGVhZGVyKzEpCmVuZAo=",
})
return lua
}
@@ -47,7 +47,7 @@ func TestLua_Parse(t *testing.T) {
t.Run("", func(t *testing.T) {
l := newLua()
defer l.Close()
- got, err := l.Parse([]byte("10headertest"))
+ got, err := l.Parse([]byte("1headertest"))
if err != nil {
t.Errorf("Lua.Parse() error = %v, wantErr %v", err, nil)
return
diff --git a/device/device.go b/device/device.go
index b42ddc0..b16b460 100644
--- a/device/device.go
+++ b/device/device.go
@@ -92,7 +92,7 @@ type Device struct {
ipcMutex sync.RWMutex
closed chan struct{}
log *Logger
-
+
awg awgType
}
@@ -833,3 +833,15 @@ func (device *Device) handlePostConfig(tempAwgType *awgType) (err error) {
return err
}
+
+func (device *Device) codecPacket(msgType uint32, packet []byte) ([]byte, error) {
+ if device.awg.luaAdapter != nil {
+ var err error
+ packet, err = device.awg.luaAdapter.Generate(int64(msgType),packet)
+ if err != nil {
+ device.log.Errorf("%v - Failed to run codec generate: %v", device, err)
+ return nil, err
+ }
+ }
+ return packet, nil
+}
diff --git a/device/device_test.go b/device/device_test.go
index e904f26..63f20df 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -107,6 +107,7 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
"private_key", hex.EncodeToString(key1[:]),
"listen_port", "0",
"replace_peers", "true",
+ "lua_codec", "CQkJZnVuY3Rpb24gZF9nZW4obXNnX3R5cGUsIGRhdGEsIGNvdW50ZXIpCgkJCQlyZXR1cm4gZGF0YQoJCQllbmQKCgkJCWZ1bmN0aW9uIGRfcGFyc2UoZGF0YSkKCQkJCXJldHVybiBkYXRhCgkJCWVuZAo=",
"jc", "5",
"jmin", "500",
"jmax", "1000",
@@ -114,8 +115,8 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
"s2", "40",
"h1", "123456",
"h2", "67543",
- "h4", "32345",
"h3", "123123",
+ "h4", "32345",
"public_key", hex.EncodeToString(pub2[:]),
"protocol_version", "1",
"replace_allowed_ips", "true",
@@ -129,6 +130,7 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
"private_key", hex.EncodeToString(key2[:]),
"listen_port", "0",
"replace_peers", "true",
+ "lua_codec", "CQkJZnVuY3Rpb24gZF9nZW4obXNnX3R5cGUsIGRhdGEsIGNvdW50ZXIpCgkJCQlyZXR1cm4gZGF0YQoJCQllbmQKCgkJCWZ1bmN0aW9uIGRfcGFyc2UoZGF0YSkKCQkJCXJldHVybiBkYXRhCgkJCWVuZAo=",
"jc", "5",
"jmin", "500",
"jmax", "1000",
@@ -136,8 +138,8 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
"s2", "40",
"h1", "123456",
"h2", "67543",
- "h4", "32345",
"h3", "123123",
+ "h4", "32345",
"public_key", hex.EncodeToString(pub1[:]),
"protocol_version", "1",
"replace_allowed_ips", "true",
@@ -192,6 +194,7 @@ func (pair *testPair) Send(
var err error
select {
case msgRecv := <-p0.tun.Inbound:
+ fmt.Printf("len(%d) msg: %x\nlen(%d) rec: %x\n", len(msg), msg, len(msgRecv), msgRecv)
if !bytes.Equal(msg, msgRecv) {
err = fmt.Errorf("%s did not transit correctly", ping)
}
@@ -275,7 +278,7 @@ func TestTwoDevicePing(t *testing.T) {
}
// Run test with -race=false to avoid the race for setting the default msgTypes 2 times
-func TestTwoDevicePingASecurity(t *testing.T) {
+func TestASecurityTwoDevicePing(t *testing.T) {
goroutineLeakCheck(t)
pair := genTestPair(t, true, true)
t.Run("ping 1.0.0.1", func(t *testing.T) {
diff --git a/device/receive.go b/device/receive.go
index 70801ef..f8551ae 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -9,9 +9,11 @@ import (
"bytes"
"encoding/binary"
"errors"
+ "fmt"
"net"
"sync"
"time"
+ "unsafe"
"github.com/amnezia-vpn/amneziawg-go/conn"
"golang.org/x/crypto/chacha20poly1305"
@@ -138,8 +140,24 @@ func (device *Device) RoutineReceiveIncoming(
// check size of packet
packet := bufsArrs[i][:size]
+ fmt.Printf("bufsArrs size: %d\n%.100x\n", size, bufsArrs[i])
+ fmt.Printf("packet before: %x\n", packet)
if device.awg.luaAdapter != nil {
- packet, err = device.awg.luaAdapter.Parse(packet)
+ ptr:= unsafe.Pointer(bufsArrs[i]) // Get pointer to the array
+ slicePtr:= (*byte)(ptr) // Type conversion to the array type
+
+ realPacket, err := device.awg.luaAdapter.Parse(packet)
+ // Copy data from newSlice to the memory pointed to by slicedPtr
+ newSliceLen:= len(realPacket)
+ for j:= 0; j < newSliceLen; j++ {
+ *(*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(slicePtr)) + uintptr(j))) = realPacket[j]
+ }
+ fmt.Printf("packet after: %x\n", packet)
+ fmt.Printf("bufsArs after size: %d\n%.100x\n", size, bufsArrs[i])
+ // diff := size - len(packet)
+ // bufsArrs[i][:len(packet)] = bufsArrs[i][diff:len(packet)]
+ size = len(packet)
+ fmt.Println("after size: ", size)
if err != nil {
device.log.Verbosef("Couldn't parse message; reason: %v", err)
continue
@@ -151,7 +169,7 @@ func (device *Device) RoutineReceiveIncoming(
junkSize := msgTypeToJunkSize[assumedMsgType]
// transport size can align with other header types;
// making sure we have the right msgType
- msgType = binary.LittleEndian.Uint32(packet[junkSize : junkSize+4])
+ msgType = binary.LittleEndian.Uint32(packet[junkSize:junkSize+4])
if msgType == assumedMsgType {
packet = packet[junkSize:]
} else {
@@ -285,15 +303,18 @@ func (device *Device) RoutineDecryption(id int) {
elem.counter = binary.LittleEndian.Uint64(counter)
// copy counter to nonce
binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
+ fmt.Printf("before decrypt: %x\n", elem.packet)
elem.packet, err = elem.keypair.receive.Open(
content[:0],
nonce[:],
content,
nil,
)
+
if err != nil {
elem.packet = nil
}
+ fmt.Printf("decrypt: %x\n", elem.packet)
}
elemsContainer.Unlock()
}
@@ -551,10 +572,13 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
continue
}
+ fmt.Printf("bufs packet: %x\n", elem.packet)
+ fmt.Printf("bufs packet: %x\n", elem.buffer[len(elem.packet)+1:MessageTransportOffsetContent+len(elem.packet)])
bufs = append(
bufs,
elem.buffer[:MessageTransportOffsetContent+len(elem.packet)],
)
+ fmt.Printf("bufs before send: %.100x\n", elem.buffer)
}
peer.rxBytes.Add(rxBytesLen)
@@ -568,6 +592,7 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
peer.timersDataReceived()
}
if len(bufs) > 0 {
+ fmt.Printf("bufs: %x\n", bufs)
_, err := device.tun.device.Write(bufs, MessageTransportOffsetContent)
if err != nil && !device.isClosed() {
device.log.Errorf("Failed to write packets to TUN device: %v", err)
diff --git a/device/send.go b/device/send.go
index 297f19e..85ad169 100644
--- a/device/send.go
+++ b/device/send.go
@@ -9,6 +9,7 @@ import (
"bytes"
"encoding/binary"
"errors"
+ "fmt"
"net"
"os"
"sync"
@@ -547,18 +548,6 @@ func calculatePaddingSize(packetSize, mtu int) int {
return paddedSize - lastUnit
}
-func (device *Device) codecPacket(msgType uint32, packet []byte) ([]byte, error) {
- if device.awg.luaAdapter != nil {
- var err error
- packet, err = device.awg.luaAdapter.Generate(int64(msgType),packet)
- if err != nil {
- device.log.Errorf("%v - Failed to run codec generate: %v", device, err)
- return nil, err
- }
- }
- return packet, nil
-}
-
/* Encrypts the elements in the queue
* and marks them for sequential consumption (by releasing the mutex)
*
@@ -603,11 +592,12 @@ func (device *Device) RoutineEncryption(id int) {
elem.packet,
nil,
)
-
+ fmt.Printf("msg: %x\n", elem.packet)
var err error
if elem.packet, err = device.codecPacket(DefaultMessageTransportType, elem.packet); err != nil {
continue
}
+ fmt.Printf("msgmsg: %x\n", elem.packet)
}
elemsContainer.Unlock()
}
@@ -662,6 +652,7 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
peer.timersDataSent()
}
for _, elem := range elemsContainer.elems {
+ fmt.Printf("send buffer: %.200x\n", elem.buffer)
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
From 068e1465387fff82501450532587ee58be60dff0 Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Sun, 9 Feb 2025 18:34:36 +0100
Subject: [PATCH 10/15] working codec
---
adapter/lua.go | 1 -
device/device_test.go | 5 ++---
device/receive.go | 35 +++++++++++------------------------
device/send.go | 4 ----
4 files changed, 13 insertions(+), 32 deletions(-)
diff --git a/adapter/lua.go b/adapter/lua.go
index 838ff84..371878d 100644
--- a/adapter/lua.go
+++ b/adapter/lua.go
@@ -69,7 +69,6 @@ func (l *Lua) Parse(data []byte) ([]byte, error) {
result := l.state.ToBytes(-1)
l.state.Pop(1)
- // copy(data, result)
return result, nil
}
diff --git a/device/device_test.go b/device/device_test.go
index 63f20df..367f097 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -107,7 +107,7 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
"private_key", hex.EncodeToString(key1[:]),
"listen_port", "0",
"replace_peers", "true",
- "lua_codec", "CQkJZnVuY3Rpb24gZF9nZW4obXNnX3R5cGUsIGRhdGEsIGNvdW50ZXIpCgkJCQlyZXR1cm4gZGF0YQoJCQllbmQKCgkJCWZ1bmN0aW9uIGRfcGFyc2UoZGF0YSkKCQkJCXJldHVybiBkYXRhCgkJCWVuZAo=",
+ "lua_codec", "CQkJZnVuY3Rpb24gZF9nZW4obXNnX3R5cGUsIGRhdGEsIGNvdW50ZXIpCgkJCQlsb2NhbCBoZWFkZXIgPSAiaGVhZGVyIgoJCQkJcmV0dXJuIGhlYWRlciAuLiBkYXRhCgkJCWVuZAoKCQkJZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJCQkJbG9jYWwgaGVhZGVyID0gImhlYWRlciIKCQkJCXJldHVybiBzdHJpbmcuc3ViKGRhdGEsICNoZWFkZXIrMSkKCQkJZW5kCg==",
"jc", "5",
"jmin", "500",
"jmax", "1000",
@@ -130,7 +130,7 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
"private_key", hex.EncodeToString(key2[:]),
"listen_port", "0",
"replace_peers", "true",
- "lua_codec", "CQkJZnVuY3Rpb24gZF9nZW4obXNnX3R5cGUsIGRhdGEsIGNvdW50ZXIpCgkJCQlyZXR1cm4gZGF0YQoJCQllbmQKCgkJCWZ1bmN0aW9uIGRfcGFyc2UoZGF0YSkKCQkJCXJldHVybiBkYXRhCgkJCWVuZAo=",
+ "lua_codec", "CQkJZnVuY3Rpb24gZF9nZW4obXNnX3R5cGUsIGRhdGEsIGNvdW50ZXIpCgkJCQlsb2NhbCBoZWFkZXIgPSAiaGVhZGVyIgoJCQkJcmV0dXJuIGhlYWRlciAuLiBkYXRhCgkJCWVuZAoKCQkJZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJCQkJbG9jYWwgaGVhZGVyID0gImhlYWRlciIKCQkJCXJldHVybiBzdHJpbmcuc3ViKGRhdGEsICNoZWFkZXIrMSkKCQkJZW5kCg==",
"jc", "5",
"jmin", "500",
"jmax", "1000",
@@ -194,7 +194,6 @@ func (pair *testPair) Send(
var err error
select {
case msgRecv := <-p0.tun.Inbound:
- fmt.Printf("len(%d) msg: %x\nlen(%d) rec: %x\n", len(msg), msg, len(msgRecv), msgRecv)
if !bytes.Equal(msg, msgRecv) {
err = fmt.Errorf("%s did not transit correctly", ping)
}
diff --git a/device/receive.go b/device/receive.go
index f8551ae..4ac5783 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -9,7 +9,6 @@ import (
"bytes"
"encoding/binary"
"errors"
- "fmt"
"net"
"sync"
"time"
@@ -138,28 +137,22 @@ func (device *Device) RoutineReceiveIncoming(
continue
}
- // check size of packet
packet := bufsArrs[i][:size]
- fmt.Printf("bufsArrs size: %d\n%.100x\n", size, bufsArrs[i])
- fmt.Printf("packet before: %x\n", packet)
if device.awg.luaAdapter != nil {
- ptr:= unsafe.Pointer(bufsArrs[i]) // Get pointer to the array
- slicePtr:= (*byte)(ptr) // Type conversion to the array type
-
realPacket, err := device.awg.luaAdapter.Parse(packet)
- // Copy data from newSlice to the memory pointed to by slicedPtr
- newSliceLen:= len(realPacket)
- for j:= 0; j < newSliceLen; j++ {
- *(*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(slicePtr)) + uintptr(j))) = realPacket[j]
+
+ packetPtr := (*byte)(unsafe.Pointer(bufsArrs[i])) // Get pointer to the array
+ // Copy data from realPacket to the memory pointed to by bufsArrs[i]
+ size = len(realPacket)
+ for j := 0; j < size; j++ {
+ *(*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(packetPtr)) + uintptr(j))) = realPacket[j]
}
- fmt.Printf("packet after: %x\n", packet)
- fmt.Printf("bufsArs after size: %d\n%.100x\n", size, bufsArrs[i])
- // diff := size - len(packet)
- // bufsArrs[i][:len(packet)] = bufsArrs[i][diff:len(packet)]
- size = len(packet)
- fmt.Println("after size: ", size)
+ packet = bufsArrs[i][:size]
if err != nil {
- device.log.Verbosef("Couldn't parse message; reason: %v", err)
+ device.log.Verbosef(
+ "Couldn't parse message; reason: %v",
+ err,
+ )
continue
}
}
@@ -303,7 +296,6 @@ func (device *Device) RoutineDecryption(id int) {
elem.counter = binary.LittleEndian.Uint64(counter)
// copy counter to nonce
binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
- fmt.Printf("before decrypt: %x\n", elem.packet)
elem.packet, err = elem.keypair.receive.Open(
content[:0],
nonce[:],
@@ -314,7 +306,6 @@ func (device *Device) RoutineDecryption(id int) {
if err != nil {
elem.packet = nil
}
- fmt.Printf("decrypt: %x\n", elem.packet)
}
elemsContainer.Unlock()
}
@@ -572,13 +563,10 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
continue
}
- fmt.Printf("bufs packet: %x\n", elem.packet)
- fmt.Printf("bufs packet: %x\n", elem.buffer[len(elem.packet)+1:MessageTransportOffsetContent+len(elem.packet)])
bufs = append(
bufs,
elem.buffer[:MessageTransportOffsetContent+len(elem.packet)],
)
- fmt.Printf("bufs before send: %.100x\n", elem.buffer)
}
peer.rxBytes.Add(rxBytesLen)
@@ -592,7 +580,6 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
peer.timersDataReceived()
}
if len(bufs) > 0 {
- fmt.Printf("bufs: %x\n", bufs)
_, err := device.tun.device.Write(bufs, MessageTransportOffsetContent)
if err != nil && !device.isClosed() {
device.log.Errorf("Failed to write packets to TUN device: %v", err)
diff --git a/device/send.go b/device/send.go
index 85ad169..5acd81c 100644
--- a/device/send.go
+++ b/device/send.go
@@ -9,7 +9,6 @@ import (
"bytes"
"encoding/binary"
"errors"
- "fmt"
"net"
"os"
"sync"
@@ -592,12 +591,10 @@ func (device *Device) RoutineEncryption(id int) {
elem.packet,
nil,
)
- fmt.Printf("msg: %x\n", elem.packet)
var err error
if elem.packet, err = device.codecPacket(DefaultMessageTransportType, elem.packet); err != nil {
continue
}
- fmt.Printf("msgmsg: %x\n", elem.packet)
}
elemsContainer.Unlock()
}
@@ -652,7 +649,6 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
peer.timersDataSent()
}
for _, elem := range elemsContainer.elems {
- fmt.Printf("send buffer: %.200x\n", elem.buffer)
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
From 0dcd528694ad8c9babd807cb38bce2aea64d3126 Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Mon, 10 Feb 2025 10:05:42 +0100
Subject: [PATCH 11/15] rename vars and complete testing
---
adapter/lua.go | 55 +++++++++++++++++++++++++++-------------
adapter/lua_test.go | 22 ++++++++++++++++
device/device.go | 26 +++++++++++--------
device/device_test.go | 15 +++++++++--
device/noise-protocol.go | 20 +++++++--------
device/receive.go | 20 +++++----------
device/send.go | 25 +++++++++---------
device/uapi.go | 6 ++---
8 files changed, 121 insertions(+), 68 deletions(-)
diff --git a/adapter/lua.go b/adapter/lua.go
index 371878d..4677345 100644
--- a/adapter/lua.go
+++ b/adapter/lua.go
@@ -8,9 +8,9 @@ import (
"github.com/aarzilli/golua/lua"
)
-// TODO: aSec sync is enough?
type Lua struct {
- state *lua.State
+ generateState *lua.State
+ parseState *lua.State
packetCounter atomic.Int64
base64LuaCode string
}
@@ -24,51 +24,72 @@ func NewLua(params LuaParams) (*Lua, error) {
if err != nil {
return nil, err
}
- fmt.Println(string(luaCode))
+ strLuaCode := string(luaCode)
+
+ generateState, err := initState(strLuaCode)
+ if err != nil {
+ return nil, err
+ }
+ parseState, err := initState(strLuaCode)
+ if err != nil {
+ return nil, err
+ }
+
+ return &Lua{
+ generateState: generateState,
+ parseState: parseState,
+ base64LuaCode: params.Base64LuaCode,
+ }, nil
+}
+
+func initState(luaCode string) (*lua.State, error) {
state := lua.NewState()
state.OpenLibs()
if err := state.DoString(string(luaCode)); err != nil {
return nil, fmt.Errorf("Error loading Lua code: %v\n", err)
}
- return &Lua{state: state, base64LuaCode: params.Base64LuaCode}, nil
+ return state, nil
}
func (l *Lua) Close() {
- l.state.Close()
+ l.generateState.Close()
+ l.parseState.Close()
}
+// Only thread safe if used by wg packet creation which happens independably
func (l *Lua) Generate(
msgType int64,
data []byte,
) ([]byte, error) {
- l.state.GetGlobal("d_gen")
+ l.generateState.GetGlobal("d_gen")
- l.state.PushInteger(msgType)
- l.state.PushBytes(data)
- l.state.PushInteger(l.packetCounter.Add(1))
+ l.generateState.PushInteger(msgType)
+ l.generateState.PushBytes(data)
+ l.generateState.PushInteger(l.packetCounter.Add(1))
- if err := l.state.Call(3, 1); err != nil {
+ if err := l.generateState.Call(3, 1); err != nil {
return nil, fmt.Errorf("Error calling Lua function: %v\n", err)
}
- result := l.state.ToBytes(-1)
- l.state.Pop(1)
+ result := l.generateState.ToBytes(-1)
+ l.generateState.Pop(1)
return result, nil
}
+// Only thread safe if used by wg packet receive which happens independably
func (l *Lua) Parse(data []byte) ([]byte, error) {
- l.state.GetGlobal("d_parse")
+ l.parseState.GetGlobal("d_parse")
- l.state.PushBytes(data)
- if err := l.state.Call(1, 1); err != nil {
+ l.parseState.PushBytes(data)
+ if err := l.parseState.Call(1, 1); err != nil {
return nil, fmt.Errorf("Error calling Lua function: %v\n", err)
}
- result := l.state.ToBytes(-1)
- l.state.Pop(1)
+ result := l.parseState.ToBytes(-1)
+ l.parseState.Pop(1)
return result, nil
}
diff --git a/adapter/lua_test.go b/adapter/lua_test.go
index 1a3d039..88389f5 100644
--- a/adapter/lua_test.go
+++ b/adapter/lua_test.go
@@ -58,3 +58,25 @@ func TestLua_Parse(t *testing.T) {
}
})
}
+
+
+var R []byte
+var l = newLua()
+func BenchmarkLuaGenerate(b *testing.B) {
+ for i := 0; i < b.N; i++ { var err error
+ R, err = l.Generate(1, []byte("test"))
+ if err != nil {
+ return
+ }
+ }
+}
+
+func BenchmarkLuaParse(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ var err error
+ R, err = l.Parse([]byte("1headertest"))
+ if err != nil {
+ return
+ }
+ }
+}
diff --git a/device/device.go b/device/device.go
index b16b460..8a96ffd 100644
--- a/device/device.go
+++ b/device/device.go
@@ -98,11 +98,11 @@ type Device struct {
type awgType struct {
isASecOn abool.AtomicBool
- aSecMux sync.RWMutex
+ mutex sync.RWMutex
aSecCfg aSecCfgType
junkCreator junkCreator
- luaAdapter *adapter.Lua
+ codec *adapter.Lua
}
type aSecCfgType struct {
@@ -434,9 +434,9 @@ func (device *Device) Close() {
device.resetProtocol()
- if device.awg.luaAdapter != nil {
- device.awg.luaAdapter.Close()
- device.awg.luaAdapter = nil
+ if device.awg.codec != nil {
+ device.awg.codec.Close()
+ device.awg.codec = nil
}
device.log.Verbosef("Device closed")
close(device.closed)
@@ -601,7 +601,7 @@ func (device *Device) handlePostConfig(tempAwgType *awgType) (err error) {
}
isASecOn := false
- device.awg.aSecMux.Lock()
+ device.awg.mutex.Lock()
if tempAwgType.aSecCfg.junkPacketCount < 0 {
err = ipcErrorf(
ipc.IpcErrorInvalid,
@@ -828,16 +828,20 @@ func (device *Device) handlePostConfig(tempAwgType *awgType) (err error) {
if device.awg.isASecOn.IsSet() {
device.awg.junkCreator, err = NewJunkCreator(device)
}
- device.awg.luaAdapter = tempAwgType.luaAdapter
- device.awg.aSecMux.Unlock()
+ device.awg.codec = tempAwgType.codec
+ device.awg.mutex.Unlock()
return err
}
-func (device *Device) codecPacket(msgType uint32, packet []byte) ([]byte, error) {
- if device.awg.luaAdapter != nil {
+func (device *Device) isCodecActive() bool {
+ return device.awg.codec != nil
+}
+
+func (device *Device) codecPacketIfActive(msgType uint32, packet []byte) ([]byte, error) {
+ if device.isCodecActive() {
var err error
- packet, err = device.awg.luaAdapter.Generate(int64(msgType),packet)
+ packet, err = device.awg.codec.Generate(int64(msgType),packet)
if err != nil {
device.log.Errorf("%v - Failed to run codec generate: %v", device, err)
return nil, err
diff --git a/device/device_test.go b/device/device_test.go
index 367f097..0f42c46 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -103,11 +103,22 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
}
pub1, pub2 := key1.publicKey(), key2.publicKey()
+ /*
+ head = "headheadhead"
+ tail = "tailtailtail"
+ function d_gen(msg_type, data, counter)
+ return head .. data .. tail
+ end
+
+ function d_parse(data)
+ return string.match(data, head.. "(.-)".. tail)
+ end
+ */
cfgs[0] = uapiCfg(
"private_key", hex.EncodeToString(key1[:]),
"listen_port", "0",
"replace_peers", "true",
- "lua_codec", "CQkJZnVuY3Rpb24gZF9nZW4obXNnX3R5cGUsIGRhdGEsIGNvdW50ZXIpCgkJCQlsb2NhbCBoZWFkZXIgPSAiaGVhZGVyIgoJCQkJcmV0dXJuIGhlYWRlciAuLiBkYXRhCgkJCWVuZAoKCQkJZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJCQkJbG9jYWwgaGVhZGVyID0gImhlYWRlciIKCQkJCXJldHVybiBzdHJpbmcuc3ViKGRhdGEsICNoZWFkZXIrMSkKCQkJZW5kCg==",
+ "lua_codec", "aGVhZCA9ICJoZWFkaGVhZGhlYWQiCnRhaWwgPSAidGFpbHRhaWx0YWlsIgpmdW5jdGlvbiBkX2dlbihtc2dfdHlwZSwgZGF0YSwgY291bnRlcikKCXJldHVybiBoZWFkIC4uIGRhdGEgLi4gdGFpbAplbmQKCmZ1bmN0aW9uIGRfcGFyc2UoZGF0YSkKICAgICAgICByZXR1cm4gc3RyaW5nLm1hdGNoKGRhdGEsIGhlYWQuLiAiKC4tKSIuLiB0YWlsKQplbmQK",
"jc", "5",
"jmin", "500",
"jmax", "1000",
@@ -130,7 +141,7 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
"private_key", hex.EncodeToString(key2[:]),
"listen_port", "0",
"replace_peers", "true",
- "lua_codec", "CQkJZnVuY3Rpb24gZF9nZW4obXNnX3R5cGUsIGRhdGEsIGNvdW50ZXIpCgkJCQlsb2NhbCBoZWFkZXIgPSAiaGVhZGVyIgoJCQkJcmV0dXJuIGhlYWRlciAuLiBkYXRhCgkJCWVuZAoKCQkJZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJCQkJbG9jYWwgaGVhZGVyID0gImhlYWRlciIKCQkJCXJldHVybiBzdHJpbmcuc3ViKGRhdGEsICNoZWFkZXIrMSkKCQkJZW5kCg==",
+ "lua_codec", "aGVhZCA9ICJoZWFkaGVhZGhlYWQiCnRhaWwgPSAidGFpbHRhaWx0YWlsIgpmdW5jdGlvbiBkX2dlbihtc2dfdHlwZSwgZGF0YSwgY291bnRlcikKCXJldHVybiBoZWFkIC4uIGRhdGEgLi4gdGFpbAplbmQKCmZ1bmN0aW9uIGRfcGFyc2UoZGF0YSkKICAgICAgICByZXR1cm4gc3RyaW5nLm1hdGNoKGRhdGEsIGhlYWQuLiAiKC4tKSIuLiB0YWlsKQplbmQK",
"jc", "5",
"jmin", "500",
"jmax", "1000",
diff --git a/device/noise-protocol.go b/device/noise-protocol.go
index 2dc0d93..75352c0 100644
--- a/device/noise-protocol.go
+++ b/device/noise-protocol.go
@@ -204,12 +204,12 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(handshake.remoteStatic[:])
- device.awg.aSecMux.RLock()
+ device.awg.mutex.RLock()
msg := MessageInitiation{
Type: MessageInitiationType,
Ephemeral: handshake.localEphemeral.publicKey(),
}
- device.awg.aSecMux.RUnlock()
+ device.awg.mutex.RUnlock()
handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:])
@@ -263,12 +263,12 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
chainKey [blake2s.Size]byte
)
- device.awg.aSecMux.RLock()
+ device.awg.mutex.RLock()
if msg.Type != MessageInitiationType {
- device.awg.aSecMux.RUnlock()
+ device.awg.mutex.RUnlock()
return nil
}
- device.awg.aSecMux.RUnlock()
+ device.awg.mutex.RUnlock()
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
@@ -383,9 +383,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
}
var msg MessageResponse
- device.awg.aSecMux.RLock()
+ device.awg.mutex.RLock()
msg.Type = MessageResponseType
- device.awg.aSecMux.RUnlock()
+ device.awg.mutex.RUnlock()
msg.Sender = handshake.localIndex
msg.Receiver = handshake.remoteIndex
@@ -435,12 +435,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
}
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
- device.awg.aSecMux.RLock()
+ device.awg.mutex.RLock()
if msg.Type != MessageResponseType {
- device.awg.aSecMux.RUnlock()
+ device.awg.mutex.RUnlock()
return nil
}
- device.awg.aSecMux.RUnlock()
+ device.awg.mutex.RUnlock()
// lookup handshake by receiver
diff --git a/device/receive.go b/device/receive.go
index 4ac5783..3df269e 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -12,7 +12,6 @@ import (
"net"
"sync"
"time"
- "unsafe"
"github.com/amnezia-vpn/amneziawg-go/conn"
"golang.org/x/crypto/chacha20poly1305"
@@ -130,7 +129,7 @@ func (device *Device) RoutineReceiveIncoming(
}
deathSpiral = 0
- device.awg.aSecMux.RLock()
+ device.awg.mutex.RLock()
// handle each packet in the batch
for i, size := range sizes[:count] {
if size < MinMessageSize {
@@ -138,15 +137,10 @@ func (device *Device) RoutineReceiveIncoming(
}
packet := bufsArrs[i][:size]
- if device.awg.luaAdapter != nil {
- realPacket, err := device.awg.luaAdapter.Parse(packet)
-
- packetPtr := (*byte)(unsafe.Pointer(bufsArrs[i])) // Get pointer to the array
- // Copy data from realPacket to the memory pointed to by bufsArrs[i]
+ if device.isCodecActive() {
+ realPacket, err := device.awg.codec.Parse(packet)
+ copy(packet, realPacket)
size = len(realPacket)
- for j := 0; j < size; j++ {
- *(*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(packetPtr)) + uintptr(j))) = realPacket[j]
- }
packet = bufsArrs[i][:size]
if err != nil {
device.log.Verbosef(
@@ -262,7 +256,7 @@ func (device *Device) RoutineReceiveIncoming(
default:
}
}
- device.awg.aSecMux.RUnlock()
+ device.awg.mutex.RUnlock()
for peer, elemsContainer := range elemsByPeer {
if peer.isRunning.Load() {
peer.queue.inbound.c <- elemsContainer
@@ -322,7 +316,7 @@ func (device *Device) RoutineHandshake(id int) {
for elem := range device.queue.handshake.c {
- device.awg.aSecMux.RLock()
+ device.awg.mutex.RLock()
// handle cookie fields and ratelimiting
@@ -474,7 +468,7 @@ func (device *Device) RoutineHandshake(id int) {
peer.SendKeepalive()
}
skip:
- device.awg.aSecMux.RUnlock()
+ device.awg.mutex.RUnlock()
device.PutMessageBuffer(elem.buffer)
}
}
diff --git a/device/send.go b/device/send.go
index 5acd81c..d6352e9 100644
--- a/device/send.go
+++ b/device/send.go
@@ -127,13 +127,14 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
)
return err
}
+
var sendBuffer [][]byte
// so only packet processed for cookie generation
var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() {
- peer.device.awg.aSecMux.RLock()
+ peer.device.awg.mutex.RLock()
junks, err := peer.device.awg.junkCreator.createJunkPackets(peer)
- peer.device.awg.aSecMux.RUnlock()
+ peer.device.awg.mutex.RUnlock()
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
@@ -153,19 +154,19 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
}
}
- peer.device.awg.aSecMux.RLock()
+ peer.device.awg.mutex.RLock()
if peer.device.awg.aSecCfg.initPacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.awg.aSecCfg.initPacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.initPacketJunkSize)
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
- peer.device.awg.aSecMux.RUnlock()
+ peer.device.awg.mutex.RUnlock()
return err
}
junkedHeader = writer.Bytes()
}
- peer.device.awg.aSecMux.RUnlock()
+ peer.device.awg.mutex.RUnlock()
}
var buf [MessageInitiationSize]byte
@@ -175,7 +176,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
peer.cookieGenerator.AddMacs(packet)
junkedHeader = append(junkedHeader, packet...)
- if junkedHeader, err = peer.device.codecPacket(DefaultMessageInitiationType, junkedHeader); err != nil {
+ if junkedHeader, err = peer.device.codecPacketIfActive(DefaultMessageInitiationType, junkedHeader); err != nil {
return err
}
@@ -215,19 +216,19 @@ func (peer *Peer) SendHandshakeResponse() error {
}
var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() {
- peer.device.awg.aSecMux.RLock()
+ peer.device.awg.mutex.RLock()
if peer.device.awg.aSecCfg.responsePacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.awg.aSecCfg.responsePacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.responsePacketJunkSize)
if err != nil {
- peer.device.awg.aSecMux.RUnlock()
+ peer.device.awg.mutex.RUnlock()
peer.device.log.Errorf("%v - %v", peer, err)
return err
}
junkedHeader = writer.Bytes()
}
- peer.device.awg.aSecMux.RUnlock()
+ peer.device.awg.mutex.RUnlock()
}
var buf [MessageResponseSize]byte
writer := bytes.NewBuffer(buf[:0])
@@ -237,7 +238,7 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.cookieGenerator.AddMacs(packet)
junkedHeader = append(junkedHeader, packet...)
- if junkedHeader, err = peer.device.codecPacket(DefaultMessageResponseType, junkedHeader); err != nil {
+ if junkedHeader, err = peer.device.codecPacketIfActive(DefaultMessageResponseType, junkedHeader); err != nil {
return err
}
@@ -286,7 +287,7 @@ func (device *Device) SendHandshakeCookie(
writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, reply)
packet := writer.Bytes()
- if packet, err = device.codecPacket(DefaultMessageCookieReplyType, packet); err != nil {
+ if packet, err = device.codecPacketIfActive(DefaultMessageCookieReplyType, packet); err != nil {
return err
}
@@ -592,7 +593,7 @@ func (device *Device) RoutineEncryption(id int) {
nil,
)
var err error
- if elem.packet, err = device.codecPacket(DefaultMessageTransportType, elem.packet); err != nil {
+ if elem.packet, err = device.codecPacketIfActive(DefaultMessageTransportType, elem.packet); err != nil {
continue
}
}
diff --git a/device/uapi.go b/device/uapi.go
index a2bb21b..850cc6c 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -98,8 +98,8 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
sendf("fwmark=%d", device.net.fwmark)
}
- if device.awg.luaAdapter != nil {
- sendf("lua_codec=%s", device.awg.luaAdapter.Base64LuaCode())
+ if device.awg.codec != nil {
+ sendf("lua_codec=%s", device.awg.codec.Base64LuaCode())
}
if device.isAdvancedSecurityOn() {
if device.awg.aSecCfg.junkPacketCount != 0 {
@@ -290,7 +290,7 @@ func (device *Device) handleDeviceLine(key, value string, tempAwgTpe *awgType) e
case "lua_codec":
device.log.Verbosef("UAPI: Updating lua_codec")
var err error
- tempAwgTpe.luaAdapter, err = adapter.NewLua(adapter.LuaParams{
+ tempAwgTpe.codec, err = adapter.NewLua(adapter.LuaParams{
Base64LuaCode: value,
})
if err != nil {
From ae81fcbd7ab0ae44d68ebcc2c80b105aa948e3d1 Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Mon, 10 Feb 2025 10:40:34 +0100
Subject: [PATCH 12/15] delete adapter makefile
---
adapter/Makefile | 2 --
1 file changed, 2 deletions(-)
delete mode 100644 adapter/Makefile
diff --git a/adapter/Makefile b/adapter/Makefile
deleted file mode 100644
index 41bc7aa..0000000
--- a/adapter/Makefile
+++ /dev/null
@@ -1,2 +0,0 @@
-all:
- go run -tags lua54 lua.go
From d03bddd7c77b585555def6d78ecbe45e0b9c876a Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Mon, 10 Feb 2025 10:43:18 +0100
Subject: [PATCH 13/15] add luajit tag for lua codec
---
Makefile | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Makefile b/Makefile
index 7a88647..979681a 100644
--- a/Makefile
+++ b/Makefile
@@ -17,7 +17,7 @@ generate-version-and-build:
@$(MAKE) amneziawg-go
amneziawg-go: $(wildcard *.go) $(wildcard */*.go)
- go build -v -o "$@"
+ go build -tags luajit -v -o "$@"
install: amneziawg-go
@install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/amneziawg-go"
From 02462a0e2fc02e78a2282c1387ba0cafc5f11040 Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Tue, 11 Feb 2025 08:08:22 +0100
Subject: [PATCH 14/15] always assign lua_codec in postConfig
---
device/device.go | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/device/device.go b/device/device.go
index 8a96ffd..6a83bf9 100644
--- a/device/device.go
+++ b/device/device.go
@@ -98,7 +98,7 @@ type Device struct {
type awgType struct {
isASecOn abool.AtomicBool
- mutex sync.RWMutex
+ mutex sync.RWMutex
aSecCfg aSecCfgType
junkCreator junkCreator
@@ -596,6 +596,8 @@ func (device *Device) resetProtocol() {
}
func (device *Device) handlePostConfig(tempAwgType *awgType) (err error) {
+ device.awg.codec = tempAwgType.codec
+
if !tempAwgType.aSecCfg.isSet {
return nil
}
@@ -828,7 +830,6 @@ func (device *Device) handlePostConfig(tempAwgType *awgType) (err error) {
if device.awg.isASecOn.IsSet() {
device.awg.junkCreator, err = NewJunkCreator(device)
}
- device.awg.codec = tempAwgType.codec
device.awg.mutex.Unlock()
return err
From c486c2eb5d46a1cd4df09b7818b565dca064561a Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Tue, 11 Feb 2025 08:40:08 +0100
Subject: [PATCH 15/15] relocate adapter to internal structure
---
device/device.go | 2 +-
{adapter => device/internal/adapter}/lua.go | 1 +
{adapter => device/internal/adapter}/lua_test.go | 0
device/uapi.go | 2 +-
4 files changed, 3 insertions(+), 2 deletions(-)
rename {adapter => device/internal/adapter}/lua.go (98%)
rename {adapter => device/internal/adapter}/lua_test.go (100%)
diff --git a/device/device.go b/device/device.go
index 6a83bf9..1e28e8f 100644
--- a/device/device.go
+++ b/device/device.go
@@ -11,8 +11,8 @@ import (
"sync/atomic"
"time"
- "github.com/amnezia-vpn/amneziawg-go/adapter"
"github.com/amnezia-vpn/amneziawg-go/conn"
+ "github.com/amnezia-vpn/amneziawg-go/device/internal/adapter"
"github.com/amnezia-vpn/amneziawg-go/ipc"
"github.com/amnezia-vpn/amneziawg-go/ratelimiter"
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
diff --git a/adapter/lua.go b/device/internal/adapter/lua.go
similarity index 98%
rename from adapter/lua.go
rename to device/internal/adapter/lua.go
index 4677345..9c5e07d 100644
--- a/adapter/lua.go
+++ b/device/internal/adapter/lua.go
@@ -26,6 +26,7 @@ func NewLua(params LuaParams) (*Lua, error) {
}
strLuaCode := string(luaCode)
+ // fmt.Println(strLuaCode)
generateState, err := initState(strLuaCode)
if err != nil {
diff --git a/adapter/lua_test.go b/device/internal/adapter/lua_test.go
similarity index 100%
rename from adapter/lua_test.go
rename to device/internal/adapter/lua_test.go
diff --git a/device/uapi.go b/device/uapi.go
index 850cc6c..7bd5209 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -18,7 +18,7 @@ import (
"sync"
"time"
- "github.com/amnezia-vpn/amneziawg-go/adapter"
+ "github.com/amnezia-vpn/amneziawg-go/device/internal/adapter"
"github.com/amnezia-vpn/amneziawg-go/ipc"
)