From f4bc11733deb06a12cec3f7f345d4da281847e37 Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Sun, 9 Feb 2025 09:50:24 +0100
Subject: [PATCH] add lua codec parsing
---
adapter/lua.go | 7 ++++-
device/device.go | 75 ++++++++++++++++++++++++------------------------
device/uapi.go | 59 ++++++++++++++++++++++---------------
3 files changed, 80 insertions(+), 61 deletions(-)
diff --git a/adapter/lua.go b/adapter/lua.go
index 1a845c4..371878d 100644
--- a/adapter/lua.go
+++ b/adapter/lua.go
@@ -12,6 +12,7 @@ import (
type Lua struct {
state *lua.State
packetCounter atomic.Int64
+ base64LuaCode string
}
type LuaParams struct {
@@ -31,7 +32,7 @@ func NewLua(params LuaParams) (*Lua, error) {
if err := state.DoString(string(luaCode)); err != nil {
return nil, fmt.Errorf("Error loading Lua code: %v\n", err)
}
- return &Lua{state: state}, nil
+ return &Lua{state: state, base64LuaCode: params.Base64LuaCode}, nil
}
func (l *Lua) Close() {
@@ -71,3 +72,7 @@ func (l *Lua) Parse(data []byte) ([]byte, error) {
return result, nil
}
+
+func (l *Lua) Base64LuaCode() string {
+ return l.base64LuaCode
+}
diff --git a/device/device.go b/device/device.go
index f375663..b42ddc0 100644
--- a/device/device.go
+++ b/device/device.go
@@ -595,43 +595,43 @@ func (device *Device) resetProtocol() {
MessageTransportType = DefaultMessageTransportType
}
-func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
- if !tempASecCfg.isSet {
+func (device *Device) handlePostConfig(tempAwgType *awgType) (err error) {
+ if !tempAwgType.aSecCfg.isSet {
return nil
}
isASecOn := false
device.awg.aSecMux.Lock()
- if tempASecCfg.junkPacketCount < 0 {
+ if tempAwgType.aSecCfg.junkPacketCount < 0 {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketCount should be non negative",
)
}
- device.awg.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount
- if tempASecCfg.junkPacketCount != 0 {
+ device.awg.aSecCfg.junkPacketCount = tempAwgType.aSecCfg.junkPacketCount
+ if tempAwgType.aSecCfg.junkPacketCount != 0 {
isASecOn = true
}
- device.awg.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize
- if tempASecCfg.junkPacketMinSize != 0 {
+ device.awg.aSecCfg.junkPacketMinSize = tempAwgType.aSecCfg.junkPacketMinSize
+ if tempAwgType.aSecCfg.junkPacketMinSize != 0 {
isASecOn = true
}
if device.awg.aSecCfg.junkPacketCount > 0 &&
- tempASecCfg.junkPacketMaxSize == tempASecCfg.junkPacketMinSize {
+ tempAwgType.aSecCfg.junkPacketMaxSize == tempAwgType.aSecCfg.junkPacketMinSize {
- tempASecCfg.junkPacketMaxSize++ // to make rand gen work
+ tempAwgType.aSecCfg.junkPacketMaxSize++ // to make rand gen work
}
- if tempASecCfg.junkPacketMaxSize >= MaxSegmentSize {
+ if tempAwgType.aSecCfg.junkPacketMaxSize >= MaxSegmentSize {
device.awg.aSecCfg.junkPacketMinSize = 0
device.awg.aSecCfg.junkPacketMaxSize = 1
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d; %w",
- tempASecCfg.junkPacketMaxSize,
+ tempAwgType.aSecCfg.junkPacketMaxSize,
MaxSegmentSize,
err,
)
@@ -639,41 +639,41 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
- tempASecCfg.junkPacketMaxSize,
+ tempAwgType.aSecCfg.junkPacketMaxSize,
MaxSegmentSize,
)
}
- } else if tempASecCfg.junkPacketMaxSize < tempASecCfg.junkPacketMinSize {
+ } else if tempAwgType.aSecCfg.junkPacketMaxSize < tempAwgType.aSecCfg.junkPacketMinSize {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"maxSize: %d; should be greater than minSize: %d; %w",
- tempASecCfg.junkPacketMaxSize,
- tempASecCfg.junkPacketMinSize,
+ tempAwgType.aSecCfg.junkPacketMaxSize,
+ tempAwgType.aSecCfg.junkPacketMinSize,
err,
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"maxSize: %d; should be greater than minSize: %d",
- tempASecCfg.junkPacketMaxSize,
- tempASecCfg.junkPacketMinSize,
+ tempAwgType.aSecCfg.junkPacketMaxSize,
+ tempAwgType.aSecCfg.junkPacketMinSize,
)
}
} else {
- device.awg.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize
+ device.awg.aSecCfg.junkPacketMaxSize = tempAwgType.aSecCfg.junkPacketMaxSize
}
- if tempASecCfg.junkPacketMaxSize != 0 {
+ if tempAwgType.aSecCfg.junkPacketMaxSize != 0 {
isASecOn = true
}
- if MessageInitiationSize+tempASecCfg.initPacketJunkSize >= MaxSegmentSize {
+ if MessageInitiationSize+tempAwgType.aSecCfg.initPacketJunkSize >= MaxSegmentSize {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
- tempASecCfg.initPacketJunkSize,
+ tempAwgType.aSecCfg.initPacketJunkSize,
MaxSegmentSize,
err,
)
@@ -681,24 +681,24 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
- tempASecCfg.initPacketJunkSize,
+ tempAwgType.aSecCfg.initPacketJunkSize,
MaxSegmentSize,
)
}
} else {
- device.awg.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize
+ device.awg.aSecCfg.initPacketJunkSize = tempAwgType.aSecCfg.initPacketJunkSize
}
- if tempASecCfg.initPacketJunkSize != 0 {
+ if tempAwgType.aSecCfg.initPacketJunkSize != 0 {
isASecOn = true
}
- if MessageResponseSize+tempASecCfg.responsePacketJunkSize >= MaxSegmentSize {
+ if MessageResponseSize+tempAwgType.aSecCfg.responsePacketJunkSize >= MaxSegmentSize {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
- tempASecCfg.responsePacketJunkSize,
+ tempAwgType.aSecCfg.responsePacketJunkSize,
MaxSegmentSize,
err,
)
@@ -706,52 +706,52 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
- tempASecCfg.responsePacketJunkSize,
+ tempAwgType.aSecCfg.responsePacketJunkSize,
MaxSegmentSize,
)
}
} else {
- device.awg.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize
+ device.awg.aSecCfg.responsePacketJunkSize = tempAwgType.aSecCfg.responsePacketJunkSize
}
- if tempASecCfg.responsePacketJunkSize != 0 {
+ if tempAwgType.aSecCfg.responsePacketJunkSize != 0 {
isASecOn = true
}
- if tempASecCfg.initPacketMagicHeader > 4 {
+ if tempAwgType.aSecCfg.initPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating init_packet_magic_header")
- device.awg.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader
+ device.awg.aSecCfg.initPacketMagicHeader = tempAwgType.aSecCfg.initPacketMagicHeader
MessageInitiationType = device.awg.aSecCfg.initPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default init type")
MessageInitiationType = DefaultMessageInitiationType
}
- if tempASecCfg.responsePacketMagicHeader > 4 {
+ if tempAwgType.aSecCfg.responsePacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
- device.awg.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader
+ device.awg.aSecCfg.responsePacketMagicHeader = tempAwgType.aSecCfg.responsePacketMagicHeader
MessageResponseType = device.awg.aSecCfg.responsePacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default response type")
MessageResponseType = DefaultMessageResponseType
}
- if tempASecCfg.underloadPacketMagicHeader > 4 {
+ if tempAwgType.aSecCfg.underloadPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
- device.awg.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader
+ device.awg.aSecCfg.underloadPacketMagicHeader = tempAwgType.aSecCfg.underloadPacketMagicHeader
MessageCookieReplyType = device.awg.aSecCfg.underloadPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default underload type")
MessageCookieReplyType = DefaultMessageCookieReplyType
}
- if tempASecCfg.transportPacketMagicHeader > 4 {
+ if tempAwgType.aSecCfg.transportPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
- device.awg.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader
+ device.awg.aSecCfg.transportPacketMagicHeader = tempAwgType.aSecCfg.transportPacketMagicHeader
MessageTransportType = device.awg.aSecCfg.transportPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default transport type")
@@ -828,6 +828,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
if device.awg.isASecOn.IsSet() {
device.awg.junkCreator, err = NewJunkCreator(device)
}
+ device.awg.luaAdapter = tempAwgType.luaAdapter
device.awg.aSecMux.Unlock()
return err
diff --git a/device/uapi.go b/device/uapi.go
index 1bce39b..a2bb21b 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -18,6 +18,7 @@ import (
"sync"
"time"
+ "github.com/amnezia-vpn/amneziawg-go/adapter"
"github.com/amnezia-vpn/amneziawg-go/ipc"
)
@@ -97,6 +98,9 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
sendf("fwmark=%d", device.net.fwmark)
}
+ if device.awg.luaAdapter != nil {
+ sendf("lua_codec=%s", device.awg.luaAdapter.Base64LuaCode())
+ }
if device.isAdvancedSecurityOn() {
if device.awg.aSecCfg.junkPacketCount != 0 {
sendf("jc=%d", device.awg.aSecCfg.junkPacketCount)
@@ -180,13 +184,13 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
peer := new(ipcSetPeer)
deviceConfig := true
- tempASecCfg := aSecCfgType{}
+ tempAwgTpe := awgType{}
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
if line == "" {
// Blank line means terminate operation.
- err := device.handlePostConfig(&tempASecCfg)
+ err := device.handlePostConfig(&tempAwgTpe)
if err != nil {
return err
}
@@ -217,7 +221,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
var err error
if deviceConfig {
- err = device.handleDeviceLine(key, value, &tempASecCfg)
+ err = device.handleDeviceLine(key, value, &tempAwgTpe)
} else {
err = device.handlePeerLine(peer, key, value)
}
@@ -225,7 +229,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
return err
}
}
- err = device.handlePostConfig(&tempASecCfg)
+ err = device.handlePostConfig(&tempAwgTpe)
if err != nil {
return err
}
@@ -237,7 +241,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
return nil
}
-func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgType) error {
+func (device *Device) handleDeviceLine(key, value string, tempAwgTpe *awgType) error {
switch key {
case "private_key":
var sk NoisePrivateKey
@@ -283,14 +287,23 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
device.log.Verbosef("UAPI: Removing all peers")
device.RemoveAllPeers()
+ case "lua_codec":
+ device.log.Verbosef("UAPI: Updating lua_codec")
+ var err error
+ tempAwgTpe.luaAdapter, err = adapter.NewLua(adapter.LuaParams{
+ Base64LuaCode: value,
+ })
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "invalid lua_codec: %w", err)
+ }
case "jc":
junkPacketCount, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_count %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_count")
- tempASecCfg.junkPacketCount = junkPacketCount
- tempASecCfg.isSet = true
+ tempAwgTpe.aSecCfg.junkPacketCount = junkPacketCount
+ tempAwgTpe.aSecCfg.isSet = true
case "jmin":
junkPacketMinSize, err := strconv.Atoi(value)
@@ -298,8 +311,8 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_min_size %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_min_size")
- tempASecCfg.junkPacketMinSize = junkPacketMinSize
- tempASecCfg.isSet = true
+ tempAwgTpe.aSecCfg.junkPacketMinSize = junkPacketMinSize
+ tempAwgTpe.aSecCfg.isSet = true
case "jmax":
junkPacketMaxSize, err := strconv.Atoi(value)
@@ -307,8 +320,8 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_max_size %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_max_size")
- tempASecCfg.junkPacketMaxSize = junkPacketMaxSize
- tempASecCfg.isSet = true
+ tempAwgTpe.aSecCfg.junkPacketMaxSize = junkPacketMaxSize
+ tempAwgTpe.aSecCfg.isSet = true
case "s1":
initPacketJunkSize, err := strconv.Atoi(value)
@@ -316,8 +329,8 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating init_packet_junk_size")
- tempASecCfg.initPacketJunkSize = initPacketJunkSize
- tempASecCfg.isSet = true
+ tempAwgTpe.aSecCfg.initPacketJunkSize = initPacketJunkSize
+ tempAwgTpe.aSecCfg.isSet = true
case "s2":
responsePacketJunkSize, err := strconv.Atoi(value)
@@ -325,40 +338,40 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating response_packet_junk_size")
- tempASecCfg.responsePacketJunkSize = responsePacketJunkSize
- tempASecCfg.isSet = true
+ tempAwgTpe.aSecCfg.responsePacketJunkSize = responsePacketJunkSize
+ tempAwgTpe.aSecCfg.isSet = true
case "h1":
initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_magic_header %w", err)
}
- tempASecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader)
- tempASecCfg.isSet = true
+ tempAwgTpe.aSecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader)
+ tempAwgTpe.aSecCfg.isSet = true
case "h2":
responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_magic_header %w", err)
}
- tempASecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader)
- tempASecCfg.isSet = true
+ tempAwgTpe.aSecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader)
+ tempAwgTpe.aSecCfg.isSet = true
case "h3":
underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse underload_packet_magic_header %w", err)
}
- tempASecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
- tempASecCfg.isSet = true
+ tempAwgTpe.aSecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
+ tempAwgTpe.aSecCfg.isSet = true
case "h4":
transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse transport_packet_magic_header %w", err)
}
- tempASecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader)
- tempASecCfg.isSet = true
+ tempAwgTpe.aSecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader)
+ tempAwgTpe.aSecCfg.isSet = true
default:
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)