diff --git a/adapter/lua.go b/adapter/lua.go index 8c56d4a..a6aceb1 100644 --- a/adapter/lua.go +++ b/adapter/lua.go @@ -1,4 +1,4 @@ -package main +package adapter import ( "encoding/base64" @@ -7,39 +7,68 @@ import ( "github.com/aarzilli/golua/lua" ) -func main() { - // luaB64 := `bG9jYWwgZnVuY3Rpb24gZF9nZW4oZGF0YSwgY291bnRlcikKCUhlYWRlciA9IHN0cmluZy5jaGFyKDB4MTIsIDB4MzQsIDB4NTYsIDB4NzgpCgktLSBsb2NhbCB0cyA9IG9zLnRpbWUoKQoJcmV0dXJuIEhlYWRlciAuLiBkYXRhCmVuZAoKbG9jYWwgZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJcmV0dXJuIHN0cmluZy5zdWIoZGF0YSwgI0hlYWRlcikKZW5kCg==` - // only d_gen - // luaB64 := `ZnVuY3Rpb24gRF9nZW4oZGF0YSkKCS0tIEhlYWRlciA9IHN0cmluZy5jaGFyKDB4MTIsIDB4MzQsIDB4NTYsIDB4NzgpCglsb2NhbCBIZWFkZXIgPSAiXHgxMlx4MzRceDU2XHg3OCIKCS0tIGxvY2FsIHRzID0gb3MudGltZSgpCglyZXR1cm4gSGVhZGVyIC4uIGRhdGEKZW5kCg==` - luaB64 := `ZnVuY3Rpb24gRF9nZW4oZGF0YSkKCS0tIEhlYWRlciA9IHN0cmluZy5jaGFyKDB4MTIsIDB4MzQsIDB4NTYsIDB4NzgpCglsb2NhbCBIZWFkZXIgPSAiXHgxMlx4MzRceDU2XHg3OCIKCWxvY2FsIHRzID0gb3MudGltZSgpCglyZXR1cm4gSGVhZGVyIC4uIGRhdGEKZW5kCg==` - sDec, _ := base64.StdEncoding.DecodeString(luaB64) - fmt.Println(string(sDec)) - luaCode := sDec - L := lua.NewState() - L.OpenLibs() - defer L.Close() +// TODO: aSec sync is enough? +type Lua struct { + state *lua.State +} + +type LuaParams struct { + LuaCode64 string +} + +func NewLua(params LuaParams) (*Lua, error) { + luaCode, err := base64.StdEncoding.DecodeString(params.LuaCode64) + if err != nil { + return nil, err + } + fmt.Println(string(luaCode)) + + state := lua.NewState() + state.OpenLibs() // Load and execute the Lua code - if err := L.DoString(string(luaCode)); err != nil { - fmt.Printf("Error loading Lua code: %v\n", err) - return + if err := state.DoString(string(luaCode)); err != nil { + return nil, fmt.Errorf("Error loading Lua code: %v\n", err) } + return &Lua{state: state}, nil +} +func (l *Lua) Close() { + l.state.Close() +} + +func (l *Lua) Generate(data []byte, counter int64) ([]byte, error) { // Push the function onto the stack - L.GetGlobal("D_gen") + l.state.GetGlobal("D_gen") // Push the argument - L.PushString("data") + l.state.PushBytes(data) + l.state.PushInteger(counter) - if err := L.Call(1, 1); err != nil { - fmt.Printf("Error calling Lua function: %v\n", err) - return + if err := l.state.Call(2, 1); err != nil { + return nil, fmt.Errorf("Error calling Lua function: %v\n", err) } - result := L.ToString(-1) - L.Pop(1) + result := l.state.ToBytes(-1) + l.state.Pop(1) - // Print the result - // fmt.Printf("Result: %x\n", []byte(result)) - fmt.Printf("Result: %s\n", result) + fmt.Printf("Result: %s\n", string(result)) + return result, nil +} + +func (l *Lua) Parse(data []byte) ([]byte, error) { + // Push the function onto the stack + l.state.GetGlobal("D_parse") + + // Push the argument + l.state.PushBytes(data) + if err := l.state.Call(1, 1); err != nil { + return nil, fmt.Errorf("Error calling Lua function: %v\n", err) + } + + result := l.state.ToBytes(-1) + l.state.Pop(1) + + fmt.Printf("Result: %s\n", string(result)) + return result, nil } diff --git a/adapter/lua_test.go b/adapter/lua_test.go new file mode 100644 index 0000000..1d35651 --- /dev/null +++ b/adapter/lua_test.go @@ -0,0 +1,60 @@ +package adapter + +import ( + "testing" +) + +func newLua() *Lua { + lua, _ := NewLua(LuaParams{ + /* + function D_gen(data, counter) + local header = "header" + return counter .. header .. data + end + + function D_parse(data) + local header = "10header" + return string.sub(data, #header+1) + end + */ + LuaCode64: "ZnVuY3Rpb24gRF9nZW4oZGF0YSwgY291bnRlcikKCWxvY2FsIGhlYWRlciA9ICJoZWFkZXIiCglyZXR1cm4gY291bnRlciAuLiBoZWFkZXIgLi4gZGF0YQplbmQKCmZ1bmN0aW9uIERfcGFyc2UoZGF0YSkKCWxvY2FsIGhlYWRlciA9ICIxMGhlYWRlciIKCXJldHVybiBzdHJpbmcuc3ViKGRhdGEsICNoZWFkZXIrMSkKZW5kCg==", + }) + return lua +} + +func TestLua_Generate(t *testing.T) { + t.Run("", func(t *testing.T) { + l := newLua() + defer l.Close() + got, err := l.Generate([]byte("test"), 10) + if err != nil { + t.Errorf( + "Lua.Generate() error = %v, wantErr %v", + err, + nil, + ) + return + } + + want := "10headertest" + if string(got) != want { + t.Errorf("Lua.Generate() = %v, want %v", string(got), want) + } + }) +} + +func TestLua_Parse(t *testing.T) { + t.Run("", func(t *testing.T) { + l := newLua() + defer l.Close() + got, err := l.Parse([]byte("10headertest")) + if err != nil { + t.Errorf("Lua.Parse() error = %v, wantErr %v", err, nil) + return + } + want := "test" + if string(got) != want { + t.Errorf("Lua.Parse() = %v, want %v", got, want) + } + }) +} diff --git a/go.mod b/go.mod index 8511642..435752a 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/amnezia-vpn/amneziawg-go go 1.23 require ( + github.com/aarzilli/golua v0.0.0-20241229084300-cd31ab23902e github.com/tevino/abool/v2 v2.1.0 golang.org/x/crypto v0.21.0 golang.org/x/net v0.21.0 @@ -12,7 +13,6 @@ require ( ) require ( - github.com/aarzilli/golua v0.0.0-20241229084300-cd31ab23902e // indirect github.com/google/btree v1.0.1 // indirect golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect )