add lua codec parsing

This commit is contained in:
Mark Puha 2025-02-09 09:50:24 +01:00
parent 6b0bbcc75c
commit f4bc11733d
3 changed files with 80 additions and 61 deletions

View file

@ -12,6 +12,7 @@ import (
type Lua struct {
state *lua.State
packetCounter atomic.Int64
base64LuaCode string
}
type LuaParams struct {
@ -31,7 +32,7 @@ func NewLua(params LuaParams) (*Lua, error) {
if err := state.DoString(string(luaCode)); err != nil {
return nil, fmt.Errorf("Error loading Lua code: %v\n", err)
}
return &Lua{state: state}, nil
return &Lua{state: state, base64LuaCode: params.Base64LuaCode}, nil
}
func (l *Lua) Close() {
@ -71,3 +72,7 @@ func (l *Lua) Parse(data []byte) ([]byte, error) {
return result, nil
}
func (l *Lua) Base64LuaCode() string {
return l.base64LuaCode
}

View file

@ -595,43 +595,43 @@ func (device *Device) resetProtocol() {
MessageTransportType = DefaultMessageTransportType
}
func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
if !tempASecCfg.isSet {
func (device *Device) handlePostConfig(tempAwgType *awgType) (err error) {
if !tempAwgType.aSecCfg.isSet {
return nil
}
isASecOn := false
device.awg.aSecMux.Lock()
if tempASecCfg.junkPacketCount < 0 {
if tempAwgType.aSecCfg.junkPacketCount < 0 {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketCount should be non negative",
)
}
device.awg.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount
if tempASecCfg.junkPacketCount != 0 {
device.awg.aSecCfg.junkPacketCount = tempAwgType.aSecCfg.junkPacketCount
if tempAwgType.aSecCfg.junkPacketCount != 0 {
isASecOn = true
}
device.awg.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize
if tempASecCfg.junkPacketMinSize != 0 {
device.awg.aSecCfg.junkPacketMinSize = tempAwgType.aSecCfg.junkPacketMinSize
if tempAwgType.aSecCfg.junkPacketMinSize != 0 {
isASecOn = true
}
if device.awg.aSecCfg.junkPacketCount > 0 &&
tempASecCfg.junkPacketMaxSize == tempASecCfg.junkPacketMinSize {
tempAwgType.aSecCfg.junkPacketMaxSize == tempAwgType.aSecCfg.junkPacketMinSize {
tempASecCfg.junkPacketMaxSize++ // to make rand gen work
tempAwgType.aSecCfg.junkPacketMaxSize++ // to make rand gen work
}
if tempASecCfg.junkPacketMaxSize >= MaxSegmentSize {
if tempAwgType.aSecCfg.junkPacketMaxSize >= MaxSegmentSize {
device.awg.aSecCfg.junkPacketMinSize = 0
device.awg.aSecCfg.junkPacketMaxSize = 1
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d; %w",
tempASecCfg.junkPacketMaxSize,
tempAwgType.aSecCfg.junkPacketMaxSize,
MaxSegmentSize,
err,
)
@ -639,41 +639,41 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
tempASecCfg.junkPacketMaxSize,
tempAwgType.aSecCfg.junkPacketMaxSize,
MaxSegmentSize,
)
}
} else if tempASecCfg.junkPacketMaxSize < tempASecCfg.junkPacketMinSize {
} else if tempAwgType.aSecCfg.junkPacketMaxSize < tempAwgType.aSecCfg.junkPacketMinSize {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"maxSize: %d; should be greater than minSize: %d; %w",
tempASecCfg.junkPacketMaxSize,
tempASecCfg.junkPacketMinSize,
tempAwgType.aSecCfg.junkPacketMaxSize,
tempAwgType.aSecCfg.junkPacketMinSize,
err,
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"maxSize: %d; should be greater than minSize: %d",
tempASecCfg.junkPacketMaxSize,
tempASecCfg.junkPacketMinSize,
tempAwgType.aSecCfg.junkPacketMaxSize,
tempAwgType.aSecCfg.junkPacketMinSize,
)
}
} else {
device.awg.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize
device.awg.aSecCfg.junkPacketMaxSize = tempAwgType.aSecCfg.junkPacketMaxSize
}
if tempASecCfg.junkPacketMaxSize != 0 {
if tempAwgType.aSecCfg.junkPacketMaxSize != 0 {
isASecOn = true
}
if MessageInitiationSize+tempASecCfg.initPacketJunkSize >= MaxSegmentSize {
if MessageInitiationSize+tempAwgType.aSecCfg.initPacketJunkSize >= MaxSegmentSize {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
tempASecCfg.initPacketJunkSize,
tempAwgType.aSecCfg.initPacketJunkSize,
MaxSegmentSize,
err,
)
@ -681,24 +681,24 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempASecCfg.initPacketJunkSize,
tempAwgType.aSecCfg.initPacketJunkSize,
MaxSegmentSize,
)
}
} else {
device.awg.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize
device.awg.aSecCfg.initPacketJunkSize = tempAwgType.aSecCfg.initPacketJunkSize
}
if tempASecCfg.initPacketJunkSize != 0 {
if tempAwgType.aSecCfg.initPacketJunkSize != 0 {
isASecOn = true
}
if MessageResponseSize+tempASecCfg.responsePacketJunkSize >= MaxSegmentSize {
if MessageResponseSize+tempAwgType.aSecCfg.responsePacketJunkSize >= MaxSegmentSize {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
tempASecCfg.responsePacketJunkSize,
tempAwgType.aSecCfg.responsePacketJunkSize,
MaxSegmentSize,
err,
)
@ -706,52 +706,52 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempASecCfg.responsePacketJunkSize,
tempAwgType.aSecCfg.responsePacketJunkSize,
MaxSegmentSize,
)
}
} else {
device.awg.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize
device.awg.aSecCfg.responsePacketJunkSize = tempAwgType.aSecCfg.responsePacketJunkSize
}
if tempASecCfg.responsePacketJunkSize != 0 {
if tempAwgType.aSecCfg.responsePacketJunkSize != 0 {
isASecOn = true
}
if tempASecCfg.initPacketMagicHeader > 4 {
if tempAwgType.aSecCfg.initPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating init_packet_magic_header")
device.awg.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader
device.awg.aSecCfg.initPacketMagicHeader = tempAwgType.aSecCfg.initPacketMagicHeader
MessageInitiationType = device.awg.aSecCfg.initPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default init type")
MessageInitiationType = DefaultMessageInitiationType
}
if tempASecCfg.responsePacketMagicHeader > 4 {
if tempAwgType.aSecCfg.responsePacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
device.awg.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader
device.awg.aSecCfg.responsePacketMagicHeader = tempAwgType.aSecCfg.responsePacketMagicHeader
MessageResponseType = device.awg.aSecCfg.responsePacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default response type")
MessageResponseType = DefaultMessageResponseType
}
if tempASecCfg.underloadPacketMagicHeader > 4 {
if tempAwgType.aSecCfg.underloadPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
device.awg.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader
device.awg.aSecCfg.underloadPacketMagicHeader = tempAwgType.aSecCfg.underloadPacketMagicHeader
MessageCookieReplyType = device.awg.aSecCfg.underloadPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default underload type")
MessageCookieReplyType = DefaultMessageCookieReplyType
}
if tempASecCfg.transportPacketMagicHeader > 4 {
if tempAwgType.aSecCfg.transportPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
device.awg.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader
device.awg.aSecCfg.transportPacketMagicHeader = tempAwgType.aSecCfg.transportPacketMagicHeader
MessageTransportType = device.awg.aSecCfg.transportPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default transport type")
@ -828,6 +828,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
if device.awg.isASecOn.IsSet() {
device.awg.junkCreator, err = NewJunkCreator(device)
}
device.awg.luaAdapter = tempAwgType.luaAdapter
device.awg.aSecMux.Unlock()
return err

View file

@ -18,6 +18,7 @@ import (
"sync"
"time"
"github.com/amnezia-vpn/amneziawg-go/adapter"
"github.com/amnezia-vpn/amneziawg-go/ipc"
)
@ -97,6 +98,9 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
sendf("fwmark=%d", device.net.fwmark)
}
if device.awg.luaAdapter != nil {
sendf("lua_codec=%s", device.awg.luaAdapter.Base64LuaCode())
}
if device.isAdvancedSecurityOn() {
if device.awg.aSecCfg.junkPacketCount != 0 {
sendf("jc=%d", device.awg.aSecCfg.junkPacketCount)
@ -180,13 +184,13 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
peer := new(ipcSetPeer)
deviceConfig := true
tempASecCfg := aSecCfgType{}
tempAwgTpe := awgType{}
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
if line == "" {
// Blank line means terminate operation.
err := device.handlePostConfig(&tempASecCfg)
err := device.handlePostConfig(&tempAwgTpe)
if err != nil {
return err
}
@ -217,7 +221,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
var err error
if deviceConfig {
err = device.handleDeviceLine(key, value, &tempASecCfg)
err = device.handleDeviceLine(key, value, &tempAwgTpe)
} else {
err = device.handlePeerLine(peer, key, value)
}
@ -225,7 +229,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
return err
}
}
err = device.handlePostConfig(&tempASecCfg)
err = device.handlePostConfig(&tempAwgTpe)
if err != nil {
return err
}
@ -237,7 +241,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
return nil
}
func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgType) error {
func (device *Device) handleDeviceLine(key, value string, tempAwgTpe *awgType) error {
switch key {
case "private_key":
var sk NoisePrivateKey
@ -283,14 +287,23 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
device.log.Verbosef("UAPI: Removing all peers")
device.RemoveAllPeers()
case "lua_codec":
device.log.Verbosef("UAPI: Updating lua_codec")
var err error
tempAwgTpe.luaAdapter, err = adapter.NewLua(adapter.LuaParams{
Base64LuaCode: value,
})
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "invalid lua_codec: %w", err)
}
case "jc":
junkPacketCount, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_count %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_count")
tempASecCfg.junkPacketCount = junkPacketCount
tempASecCfg.isSet = true
tempAwgTpe.aSecCfg.junkPacketCount = junkPacketCount
tempAwgTpe.aSecCfg.isSet = true
case "jmin":
junkPacketMinSize, err := strconv.Atoi(value)
@ -298,8 +311,8 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_min_size %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_min_size")
tempASecCfg.junkPacketMinSize = junkPacketMinSize
tempASecCfg.isSet = true
tempAwgTpe.aSecCfg.junkPacketMinSize = junkPacketMinSize
tempAwgTpe.aSecCfg.isSet = true
case "jmax":
junkPacketMaxSize, err := strconv.Atoi(value)
@ -307,8 +320,8 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_max_size %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_max_size")
tempASecCfg.junkPacketMaxSize = junkPacketMaxSize
tempASecCfg.isSet = true
tempAwgTpe.aSecCfg.junkPacketMaxSize = junkPacketMaxSize
tempAwgTpe.aSecCfg.isSet = true
case "s1":
initPacketJunkSize, err := strconv.Atoi(value)
@ -316,8 +329,8 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating init_packet_junk_size")
tempASecCfg.initPacketJunkSize = initPacketJunkSize
tempASecCfg.isSet = true
tempAwgTpe.aSecCfg.initPacketJunkSize = initPacketJunkSize
tempAwgTpe.aSecCfg.isSet = true
case "s2":
responsePacketJunkSize, err := strconv.Atoi(value)
@ -325,40 +338,40 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating response_packet_junk_size")
tempASecCfg.responsePacketJunkSize = responsePacketJunkSize
tempASecCfg.isSet = true
tempAwgTpe.aSecCfg.responsePacketJunkSize = responsePacketJunkSize
tempAwgTpe.aSecCfg.isSet = true
case "h1":
initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_magic_header %w", err)
}
tempASecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader)
tempASecCfg.isSet = true
tempAwgTpe.aSecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader)
tempAwgTpe.aSecCfg.isSet = true
case "h2":
responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_magic_header %w", err)
}
tempASecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader)
tempASecCfg.isSet = true
tempAwgTpe.aSecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader)
tempAwgTpe.aSecCfg.isSet = true
case "h3":
underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse underload_packet_magic_header %w", err)
}
tempASecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
tempASecCfg.isSet = true
tempAwgTpe.aSecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
tempAwgTpe.aSecCfg.isSet = true
case "h4":
transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse transport_packet_magic_header %w", err)
}
tempASecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader)
tempASecCfg.isSet = true
tempAwgTpe.aSecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader)
tempAwgTpe.aSecCfg.isSet = true
default:
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)