From 55f4715a5055fd1aa84d640bb0574f9e64c5ba37 Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Sat, 8 Feb 2025 14:16:03 +0100
Subject: [PATCH] create lua adapter and test
---
adapter/lua.go | 79 +++++++++++++++++++++++++++++++--------------
adapter/lua_test.go | 60 ++++++++++++++++++++++++++++++++++
go.mod | 2 +-
3 files changed, 115 insertions(+), 26 deletions(-)
create mode 100644 adapter/lua_test.go
diff --git a/adapter/lua.go b/adapter/lua.go
index 8c56d4a..a6aceb1 100644
--- a/adapter/lua.go
+++ b/adapter/lua.go
@@ -1,4 +1,4 @@
-package main
+package adapter
import (
"encoding/base64"
@@ -7,39 +7,68 @@ import (
"github.com/aarzilli/golua/lua"
)
-func main() {
- // luaB64 := `bG9jYWwgZnVuY3Rpb24gZF9nZW4oZGF0YSwgY291bnRlcikKCUhlYWRlciA9IHN0cmluZy5jaGFyKDB4MTIsIDB4MzQsIDB4NTYsIDB4NzgpCgktLSBsb2NhbCB0cyA9IG9zLnRpbWUoKQoJcmV0dXJuIEhlYWRlciAuLiBkYXRhCmVuZAoKbG9jYWwgZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJcmV0dXJuIHN0cmluZy5zdWIoZGF0YSwgI0hlYWRlcikKZW5kCg==`
- // only d_gen
- // luaB64 := `ZnVuY3Rpb24gRF9nZW4oZGF0YSkKCS0tIEhlYWRlciA9IHN0cmluZy5jaGFyKDB4MTIsIDB4MzQsIDB4NTYsIDB4NzgpCglsb2NhbCBIZWFkZXIgPSAiXHgxMlx4MzRceDU2XHg3OCIKCS0tIGxvY2FsIHRzID0gb3MudGltZSgpCglyZXR1cm4gSGVhZGVyIC4uIGRhdGEKZW5kCg==`
- luaB64 := `ZnVuY3Rpb24gRF9nZW4oZGF0YSkKCS0tIEhlYWRlciA9IHN0cmluZy5jaGFyKDB4MTIsIDB4MzQsIDB4NTYsIDB4NzgpCglsb2NhbCBIZWFkZXIgPSAiXHgxMlx4MzRceDU2XHg3OCIKCWxvY2FsIHRzID0gb3MudGltZSgpCglyZXR1cm4gSGVhZGVyIC4uIGRhdGEKZW5kCg==`
- sDec, _ := base64.StdEncoding.DecodeString(luaB64)
- fmt.Println(string(sDec))
- luaCode := sDec
- L := lua.NewState()
- L.OpenLibs()
- defer L.Close()
+// TODO: aSec sync is enough?
+type Lua struct {
+ state *lua.State
+}
+
+type LuaParams struct {
+ LuaCode64 string
+}
+
+func NewLua(params LuaParams) (*Lua, error) {
+ luaCode, err := base64.StdEncoding.DecodeString(params.LuaCode64)
+ if err != nil {
+ return nil, err
+ }
+ fmt.Println(string(luaCode))
+
+ state := lua.NewState()
+ state.OpenLibs()
// Load and execute the Lua code
- if err := L.DoString(string(luaCode)); err != nil {
- fmt.Printf("Error loading Lua code: %v\n", err)
- return
+ if err := state.DoString(string(luaCode)); err != nil {
+ return nil, fmt.Errorf("Error loading Lua code: %v\n", err)
}
+ return &Lua{state: state}, nil
+}
+func (l *Lua) Close() {
+ l.state.Close()
+}
+
+func (l *Lua) Generate(data []byte, counter int64) ([]byte, error) {
// Push the function onto the stack
- L.GetGlobal("D_gen")
+ l.state.GetGlobal("D_gen")
// Push the argument
- L.PushString("data")
+ l.state.PushBytes(data)
+ l.state.PushInteger(counter)
- if err := L.Call(1, 1); err != nil {
- fmt.Printf("Error calling Lua function: %v\n", err)
- return
+ if err := l.state.Call(2, 1); err != nil {
+ return nil, fmt.Errorf("Error calling Lua function: %v\n", err)
}
- result := L.ToString(-1)
- L.Pop(1)
+ result := l.state.ToBytes(-1)
+ l.state.Pop(1)
- // Print the result
- // fmt.Printf("Result: %x\n", []byte(result))
- fmt.Printf("Result: %s\n", result)
+ fmt.Printf("Result: %s\n", string(result))
+ return result, nil
+}
+
+func (l *Lua) Parse(data []byte) ([]byte, error) {
+ // Push the function onto the stack
+ l.state.GetGlobal("D_parse")
+
+ // Push the argument
+ l.state.PushBytes(data)
+ if err := l.state.Call(1, 1); err != nil {
+ return nil, fmt.Errorf("Error calling Lua function: %v\n", err)
+ }
+
+ result := l.state.ToBytes(-1)
+ l.state.Pop(1)
+
+ fmt.Printf("Result: %s\n", string(result))
+ return result, nil
}
diff --git a/adapter/lua_test.go b/adapter/lua_test.go
new file mode 100644
index 0000000..1d35651
--- /dev/null
+++ b/adapter/lua_test.go
@@ -0,0 +1,60 @@
+package adapter
+
+import (
+ "testing"
+)
+
+func newLua() *Lua {
+ lua, _ := NewLua(LuaParams{
+ /*
+ function D_gen(data, counter)
+ local header = "header"
+ return counter .. header .. data
+ end
+
+ function D_parse(data)
+ local header = "10header"
+ return string.sub(data, #header+1)
+ end
+ */
+ LuaCode64: "ZnVuY3Rpb24gRF9nZW4oZGF0YSwgY291bnRlcikKCWxvY2FsIGhlYWRlciA9ICJoZWFkZXIiCglyZXR1cm4gY291bnRlciAuLiBoZWFkZXIgLi4gZGF0YQplbmQKCmZ1bmN0aW9uIERfcGFyc2UoZGF0YSkKCWxvY2FsIGhlYWRlciA9ICIxMGhlYWRlciIKCXJldHVybiBzdHJpbmcuc3ViKGRhdGEsICNoZWFkZXIrMSkKZW5kCg==",
+ })
+ return lua
+}
+
+func TestLua_Generate(t *testing.T) {
+ t.Run("", func(t *testing.T) {
+ l := newLua()
+ defer l.Close()
+ got, err := l.Generate([]byte("test"), 10)
+ if err != nil {
+ t.Errorf(
+ "Lua.Generate() error = %v, wantErr %v",
+ err,
+ nil,
+ )
+ return
+ }
+
+ want := "10headertest"
+ if string(got) != want {
+ t.Errorf("Lua.Generate() = %v, want %v", string(got), want)
+ }
+ })
+}
+
+func TestLua_Parse(t *testing.T) {
+ t.Run("", func(t *testing.T) {
+ l := newLua()
+ defer l.Close()
+ got, err := l.Parse([]byte("10headertest"))
+ if err != nil {
+ t.Errorf("Lua.Parse() error = %v, wantErr %v", err, nil)
+ return
+ }
+ want := "test"
+ if string(got) != want {
+ t.Errorf("Lua.Parse() = %v, want %v", got, want)
+ }
+ })
+}
diff --git a/go.mod b/go.mod
index 8511642..435752a 100644
--- a/go.mod
+++ b/go.mod
@@ -3,6 +3,7 @@ module github.com/amnezia-vpn/amneziawg-go
go 1.23
require (
+ github.com/aarzilli/golua v0.0.0-20241229084300-cd31ab23902e
github.com/tevino/abool/v2 v2.1.0
golang.org/x/crypto v0.21.0
golang.org/x/net v0.21.0
@@ -12,7 +13,6 @@ require (
)
require (
- github.com/aarzilli/golua v0.0.0-20241229084300-cd31ab23902e // indirect
github.com/google/btree v1.0.1 // indirect
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect
)