create lua adapter and test

This commit is contained in:
Mark Puha 2025-02-08 14:16:03 +01:00
parent 44217caa0d
commit 55f4715a50
3 changed files with 115 additions and 26 deletions

View file

@ -1,4 +1,4 @@
package main
package adapter
import (
"encoding/base64"
@ -7,39 +7,68 @@ import (
"github.com/aarzilli/golua/lua"
)
func main() {
// luaB64 := `bG9jYWwgZnVuY3Rpb24gZF9nZW4oZGF0YSwgY291bnRlcikKCUhlYWRlciA9IHN0cmluZy5jaGFyKDB4MTIsIDB4MzQsIDB4NTYsIDB4NzgpCgktLSBsb2NhbCB0cyA9IG9zLnRpbWUoKQoJcmV0dXJuIEhlYWRlciAuLiBkYXRhCmVuZAoKbG9jYWwgZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJcmV0dXJuIHN0cmluZy5zdWIoZGF0YSwgI0hlYWRlcikKZW5kCg==`
// only d_gen
// luaB64 := `ZnVuY3Rpb24gRF9nZW4oZGF0YSkKCS0tIEhlYWRlciA9IHN0cmluZy5jaGFyKDB4MTIsIDB4MzQsIDB4NTYsIDB4NzgpCglsb2NhbCBIZWFkZXIgPSAiXHgxMlx4MzRceDU2XHg3OCIKCS0tIGxvY2FsIHRzID0gb3MudGltZSgpCglyZXR1cm4gSGVhZGVyIC4uIGRhdGEKZW5kCg==`
luaB64 := `ZnVuY3Rpb24gRF9nZW4oZGF0YSkKCS0tIEhlYWRlciA9IHN0cmluZy5jaGFyKDB4MTIsIDB4MzQsIDB4NTYsIDB4NzgpCglsb2NhbCBIZWFkZXIgPSAiXHgxMlx4MzRceDU2XHg3OCIKCWxvY2FsIHRzID0gb3MudGltZSgpCglyZXR1cm4gSGVhZGVyIC4uIGRhdGEKZW5kCg==`
sDec, _ := base64.StdEncoding.DecodeString(luaB64)
fmt.Println(string(sDec))
luaCode := sDec
L := lua.NewState()
L.OpenLibs()
defer L.Close()
// TODO: aSec sync is enough?
type Lua struct {
state *lua.State
}
type LuaParams struct {
LuaCode64 string
}
func NewLua(params LuaParams) (*Lua, error) {
luaCode, err := base64.StdEncoding.DecodeString(params.LuaCode64)
if err != nil {
return nil, err
}
fmt.Println(string(luaCode))
state := lua.NewState()
state.OpenLibs()
// Load and execute the Lua code
if err := L.DoString(string(luaCode)); err != nil {
fmt.Printf("Error loading Lua code: %v\n", err)
return
if err := state.DoString(string(luaCode)); err != nil {
return nil, fmt.Errorf("Error loading Lua code: %v\n", err)
}
return &Lua{state: state}, nil
}
func (l *Lua) Close() {
l.state.Close()
}
func (l *Lua) Generate(data []byte, counter int64) ([]byte, error) {
// Push the function onto the stack
L.GetGlobal("D_gen")
l.state.GetGlobal("D_gen")
// Push the argument
L.PushString("data")
l.state.PushBytes(data)
l.state.PushInteger(counter)
if err := L.Call(1, 1); err != nil {
fmt.Printf("Error calling Lua function: %v\n", err)
return
if err := l.state.Call(2, 1); err != nil {
return nil, fmt.Errorf("Error calling Lua function: %v\n", err)
}
result := L.ToString(-1)
L.Pop(1)
result := l.state.ToBytes(-1)
l.state.Pop(1)
// Print the result
// fmt.Printf("Result: %x\n", []byte(result))
fmt.Printf("Result: %s\n", result)
fmt.Printf("Result: %s\n", string(result))
return result, nil
}
func (l *Lua) Parse(data []byte) ([]byte, error) {
// Push the function onto the stack
l.state.GetGlobal("D_parse")
// Push the argument
l.state.PushBytes(data)
if err := l.state.Call(1, 1); err != nil {
return nil, fmt.Errorf("Error calling Lua function: %v\n", err)
}
result := l.state.ToBytes(-1)
l.state.Pop(1)
fmt.Printf("Result: %s\n", string(result))
return result, nil
}

60
adapter/lua_test.go Normal file
View file

@ -0,0 +1,60 @@
package adapter
import (
"testing"
)
func newLua() *Lua {
lua, _ := NewLua(LuaParams{
/*
function D_gen(data, counter)
local header = "header"
return counter .. header .. data
end
function D_parse(data)
local header = "10header"
return string.sub(data, #header+1)
end
*/
LuaCode64: "ZnVuY3Rpb24gRF9nZW4oZGF0YSwgY291bnRlcikKCWxvY2FsIGhlYWRlciA9ICJoZWFkZXIiCglyZXR1cm4gY291bnRlciAuLiBoZWFkZXIgLi4gZGF0YQplbmQKCmZ1bmN0aW9uIERfcGFyc2UoZGF0YSkKCWxvY2FsIGhlYWRlciA9ICIxMGhlYWRlciIKCXJldHVybiBzdHJpbmcuc3ViKGRhdGEsICNoZWFkZXIrMSkKZW5kCg==",
})
return lua
}
func TestLua_Generate(t *testing.T) {
t.Run("", func(t *testing.T) {
l := newLua()
defer l.Close()
got, err := l.Generate([]byte("test"), 10)
if err != nil {
t.Errorf(
"Lua.Generate() error = %v, wantErr %v",
err,
nil,
)
return
}
want := "10headertest"
if string(got) != want {
t.Errorf("Lua.Generate() = %v, want %v", string(got), want)
}
})
}
func TestLua_Parse(t *testing.T) {
t.Run("", func(t *testing.T) {
l := newLua()
defer l.Close()
got, err := l.Parse([]byte("10headertest"))
if err != nil {
t.Errorf("Lua.Parse() error = %v, wantErr %v", err, nil)
return
}
want := "test"
if string(got) != want {
t.Errorf("Lua.Parse() = %v, want %v", got, want)
}
})
}

2
go.mod
View file

@ -3,6 +3,7 @@ module github.com/amnezia-vpn/amneziawg-go
go 1.23
require (
github.com/aarzilli/golua v0.0.0-20241229084300-cd31ab23902e
github.com/tevino/abool/v2 v2.1.0
golang.org/x/crypto v0.21.0
golang.org/x/net v0.21.0
@ -12,7 +13,6 @@ require (
)
require (
github.com/aarzilli/golua v0.0.0-20241229084300-cd31ab23902e // indirect
github.com/google/btree v1.0.1 // indirect
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect
)