From f4bc11733deb06a12cec3f7f345d4da281847e37 Mon Sep 17 00:00:00 2001 From: Mark Puha Date: Sun, 9 Feb 2025 09:50:24 +0100 Subject: [PATCH] add lua codec parsing --- adapter/lua.go | 7 ++++- device/device.go | 75 ++++++++++++++++++++++++------------------------ device/uapi.go | 59 ++++++++++++++++++++++--------------- 3 files changed, 80 insertions(+), 61 deletions(-) diff --git a/adapter/lua.go b/adapter/lua.go index 1a845c4..371878d 100644 --- a/adapter/lua.go +++ b/adapter/lua.go @@ -12,6 +12,7 @@ import ( type Lua struct { state *lua.State packetCounter atomic.Int64 + base64LuaCode string } type LuaParams struct { @@ -31,7 +32,7 @@ func NewLua(params LuaParams) (*Lua, error) { if err := state.DoString(string(luaCode)); err != nil { return nil, fmt.Errorf("Error loading Lua code: %v\n", err) } - return &Lua{state: state}, nil + return &Lua{state: state, base64LuaCode: params.Base64LuaCode}, nil } func (l *Lua) Close() { @@ -71,3 +72,7 @@ func (l *Lua) Parse(data []byte) ([]byte, error) { return result, nil } + +func (l *Lua) Base64LuaCode() string { + return l.base64LuaCode +} diff --git a/device/device.go b/device/device.go index f375663..b42ddc0 100644 --- a/device/device.go +++ b/device/device.go @@ -595,43 +595,43 @@ func (device *Device) resetProtocol() { MessageTransportType = DefaultMessageTransportType } -func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { - if !tempASecCfg.isSet { +func (device *Device) handlePostConfig(tempAwgType *awgType) (err error) { + if !tempAwgType.aSecCfg.isSet { return nil } isASecOn := false device.awg.aSecMux.Lock() - if tempASecCfg.junkPacketCount < 0 { + if tempAwgType.aSecCfg.junkPacketCount < 0 { err = ipcErrorf( ipc.IpcErrorInvalid, "JunkPacketCount should be non negative", ) } - device.awg.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount - if tempASecCfg.junkPacketCount != 0 { + device.awg.aSecCfg.junkPacketCount = tempAwgType.aSecCfg.junkPacketCount + if tempAwgType.aSecCfg.junkPacketCount != 0 { isASecOn = true } - device.awg.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize - if tempASecCfg.junkPacketMinSize != 0 { + device.awg.aSecCfg.junkPacketMinSize = tempAwgType.aSecCfg.junkPacketMinSize + if tempAwgType.aSecCfg.junkPacketMinSize != 0 { isASecOn = true } if device.awg.aSecCfg.junkPacketCount > 0 && - tempASecCfg.junkPacketMaxSize == tempASecCfg.junkPacketMinSize { + tempAwgType.aSecCfg.junkPacketMaxSize == tempAwgType.aSecCfg.junkPacketMinSize { - tempASecCfg.junkPacketMaxSize++ // to make rand gen work + tempAwgType.aSecCfg.junkPacketMaxSize++ // to make rand gen work } - if tempASecCfg.junkPacketMaxSize >= MaxSegmentSize { + if tempAwgType.aSecCfg.junkPacketMaxSize >= MaxSegmentSize { device.awg.aSecCfg.junkPacketMinSize = 0 device.awg.aSecCfg.junkPacketMaxSize = 1 if err != nil { err = ipcErrorf( ipc.IpcErrorInvalid, "JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d; %w", - tempASecCfg.junkPacketMaxSize, + tempAwgType.aSecCfg.junkPacketMaxSize, MaxSegmentSize, err, ) @@ -639,41 +639,41 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { err = ipcErrorf( ipc.IpcErrorInvalid, "JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d", - tempASecCfg.junkPacketMaxSize, + tempAwgType.aSecCfg.junkPacketMaxSize, MaxSegmentSize, ) } - } else if tempASecCfg.junkPacketMaxSize < tempASecCfg.junkPacketMinSize { + } else if tempAwgType.aSecCfg.junkPacketMaxSize < tempAwgType.aSecCfg.junkPacketMinSize { if err != nil { err = ipcErrorf( ipc.IpcErrorInvalid, "maxSize: %d; should be greater than minSize: %d; %w", - tempASecCfg.junkPacketMaxSize, - tempASecCfg.junkPacketMinSize, + tempAwgType.aSecCfg.junkPacketMaxSize, + tempAwgType.aSecCfg.junkPacketMinSize, err, ) } else { err = ipcErrorf( ipc.IpcErrorInvalid, "maxSize: %d; should be greater than minSize: %d", - tempASecCfg.junkPacketMaxSize, - tempASecCfg.junkPacketMinSize, + tempAwgType.aSecCfg.junkPacketMaxSize, + tempAwgType.aSecCfg.junkPacketMinSize, ) } } else { - device.awg.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize + device.awg.aSecCfg.junkPacketMaxSize = tempAwgType.aSecCfg.junkPacketMaxSize } - if tempASecCfg.junkPacketMaxSize != 0 { + if tempAwgType.aSecCfg.junkPacketMaxSize != 0 { isASecOn = true } - if MessageInitiationSize+tempASecCfg.initPacketJunkSize >= MaxSegmentSize { + if MessageInitiationSize+tempAwgType.aSecCfg.initPacketJunkSize >= MaxSegmentSize { if err != nil { err = ipcErrorf( ipc.IpcErrorInvalid, `init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`, - tempASecCfg.initPacketJunkSize, + tempAwgType.aSecCfg.initPacketJunkSize, MaxSegmentSize, err, ) @@ -681,24 +681,24 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { err = ipcErrorf( ipc.IpcErrorInvalid, `init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`, - tempASecCfg.initPacketJunkSize, + tempAwgType.aSecCfg.initPacketJunkSize, MaxSegmentSize, ) } } else { - device.awg.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize + device.awg.aSecCfg.initPacketJunkSize = tempAwgType.aSecCfg.initPacketJunkSize } - if tempASecCfg.initPacketJunkSize != 0 { + if tempAwgType.aSecCfg.initPacketJunkSize != 0 { isASecOn = true } - if MessageResponseSize+tempASecCfg.responsePacketJunkSize >= MaxSegmentSize { + if MessageResponseSize+tempAwgType.aSecCfg.responsePacketJunkSize >= MaxSegmentSize { if err != nil { err = ipcErrorf( ipc.IpcErrorInvalid, `response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`, - tempASecCfg.responsePacketJunkSize, + tempAwgType.aSecCfg.responsePacketJunkSize, MaxSegmentSize, err, ) @@ -706,52 +706,52 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { err = ipcErrorf( ipc.IpcErrorInvalid, `response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`, - tempASecCfg.responsePacketJunkSize, + tempAwgType.aSecCfg.responsePacketJunkSize, MaxSegmentSize, ) } } else { - device.awg.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize + device.awg.aSecCfg.responsePacketJunkSize = tempAwgType.aSecCfg.responsePacketJunkSize } - if tempASecCfg.responsePacketJunkSize != 0 { + if tempAwgType.aSecCfg.responsePacketJunkSize != 0 { isASecOn = true } - if tempASecCfg.initPacketMagicHeader > 4 { + if tempAwgType.aSecCfg.initPacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating init_packet_magic_header") - device.awg.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader + device.awg.aSecCfg.initPacketMagicHeader = tempAwgType.aSecCfg.initPacketMagicHeader MessageInitiationType = device.awg.aSecCfg.initPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default init type") MessageInitiationType = DefaultMessageInitiationType } - if tempASecCfg.responsePacketMagicHeader > 4 { + if tempAwgType.aSecCfg.responsePacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating response_packet_magic_header") - device.awg.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader + device.awg.aSecCfg.responsePacketMagicHeader = tempAwgType.aSecCfg.responsePacketMagicHeader MessageResponseType = device.awg.aSecCfg.responsePacketMagicHeader } else { device.log.Verbosef("UAPI: Using default response type") MessageResponseType = DefaultMessageResponseType } - if tempASecCfg.underloadPacketMagicHeader > 4 { + if tempAwgType.aSecCfg.underloadPacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating underload_packet_magic_header") - device.awg.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader + device.awg.aSecCfg.underloadPacketMagicHeader = tempAwgType.aSecCfg.underloadPacketMagicHeader MessageCookieReplyType = device.awg.aSecCfg.underloadPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default underload type") MessageCookieReplyType = DefaultMessageCookieReplyType } - if tempASecCfg.transportPacketMagicHeader > 4 { + if tempAwgType.aSecCfg.transportPacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating transport_packet_magic_header") - device.awg.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader + device.awg.aSecCfg.transportPacketMagicHeader = tempAwgType.aSecCfg.transportPacketMagicHeader MessageTransportType = device.awg.aSecCfg.transportPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default transport type") @@ -828,6 +828,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { if device.awg.isASecOn.IsSet() { device.awg.junkCreator, err = NewJunkCreator(device) } + device.awg.luaAdapter = tempAwgType.luaAdapter device.awg.aSecMux.Unlock() return err diff --git a/device/uapi.go b/device/uapi.go index 1bce39b..a2bb21b 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -18,6 +18,7 @@ import ( "sync" "time" + "github.com/amnezia-vpn/amneziawg-go/adapter" "github.com/amnezia-vpn/amneziawg-go/ipc" ) @@ -97,6 +98,9 @@ func (device *Device) IpcGetOperation(w io.Writer) error { sendf("fwmark=%d", device.net.fwmark) } + if device.awg.luaAdapter != nil { + sendf("lua_codec=%s", device.awg.luaAdapter.Base64LuaCode()) + } if device.isAdvancedSecurityOn() { if device.awg.aSecCfg.junkPacketCount != 0 { sendf("jc=%d", device.awg.aSecCfg.junkPacketCount) @@ -180,13 +184,13 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { peer := new(ipcSetPeer) deviceConfig := true - tempASecCfg := aSecCfgType{} + tempAwgTpe := awgType{} scanner := bufio.NewScanner(r) for scanner.Scan() { line := scanner.Text() if line == "" { // Blank line means terminate operation. - err := device.handlePostConfig(&tempASecCfg) + err := device.handlePostConfig(&tempAwgTpe) if err != nil { return err } @@ -217,7 +221,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { var err error if deviceConfig { - err = device.handleDeviceLine(key, value, &tempASecCfg) + err = device.handleDeviceLine(key, value, &tempAwgTpe) } else { err = device.handlePeerLine(peer, key, value) } @@ -225,7 +229,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { return err } } - err = device.handlePostConfig(&tempASecCfg) + err = device.handlePostConfig(&tempAwgTpe) if err != nil { return err } @@ -237,7 +241,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { return nil } -func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgType) error { +func (device *Device) handleDeviceLine(key, value string, tempAwgTpe *awgType) error { switch key { case "private_key": var sk NoisePrivateKey @@ -283,14 +287,23 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy device.log.Verbosef("UAPI: Removing all peers") device.RemoveAllPeers() + case "lua_codec": + device.log.Verbosef("UAPI: Updating lua_codec") + var err error + tempAwgTpe.luaAdapter, err = adapter.NewLua(adapter.LuaParams{ + Base64LuaCode: value, + }) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "invalid lua_codec: %w", err) + } case "jc": junkPacketCount, err := strconv.Atoi(value) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_count %w", err) } device.log.Verbosef("UAPI: Updating junk_packet_count") - tempASecCfg.junkPacketCount = junkPacketCount - tempASecCfg.isSet = true + tempAwgTpe.aSecCfg.junkPacketCount = junkPacketCount + tempAwgTpe.aSecCfg.isSet = true case "jmin": junkPacketMinSize, err := strconv.Atoi(value) @@ -298,8 +311,8 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_min_size %w", err) } device.log.Verbosef("UAPI: Updating junk_packet_min_size") - tempASecCfg.junkPacketMinSize = junkPacketMinSize - tempASecCfg.isSet = true + tempAwgTpe.aSecCfg.junkPacketMinSize = junkPacketMinSize + tempAwgTpe.aSecCfg.isSet = true case "jmax": junkPacketMaxSize, err := strconv.Atoi(value) @@ -307,8 +320,8 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_max_size %w", err) } device.log.Verbosef("UAPI: Updating junk_packet_max_size") - tempASecCfg.junkPacketMaxSize = junkPacketMaxSize - tempASecCfg.isSet = true + tempAwgTpe.aSecCfg.junkPacketMaxSize = junkPacketMaxSize + tempAwgTpe.aSecCfg.isSet = true case "s1": initPacketJunkSize, err := strconv.Atoi(value) @@ -316,8 +329,8 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_junk_size %w", err) } device.log.Verbosef("UAPI: Updating init_packet_junk_size") - tempASecCfg.initPacketJunkSize = initPacketJunkSize - tempASecCfg.isSet = true + tempAwgTpe.aSecCfg.initPacketJunkSize = initPacketJunkSize + tempAwgTpe.aSecCfg.isSet = true case "s2": responsePacketJunkSize, err := strconv.Atoi(value) @@ -325,40 +338,40 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_junk_size %w", err) } device.log.Verbosef("UAPI: Updating response_packet_junk_size") - tempASecCfg.responsePacketJunkSize = responsePacketJunkSize - tempASecCfg.isSet = true + tempAwgTpe.aSecCfg.responsePacketJunkSize = responsePacketJunkSize + tempAwgTpe.aSecCfg.isSet = true case "h1": initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_magic_header %w", err) } - tempASecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader) - tempASecCfg.isSet = true + tempAwgTpe.aSecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader) + tempAwgTpe.aSecCfg.isSet = true case "h2": responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_magic_header %w", err) } - tempASecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader) - tempASecCfg.isSet = true + tempAwgTpe.aSecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader) + tempAwgTpe.aSecCfg.isSet = true case "h3": underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse underload_packet_magic_header %w", err) } - tempASecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader) - tempASecCfg.isSet = true + tempAwgTpe.aSecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader) + tempAwgTpe.aSecCfg.isSet = true case "h4": transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse transport_packet_magic_header %w", err) } - tempASecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader) - tempASecCfg.isSet = true + tempAwgTpe.aSecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader) + tempAwgTpe.aSecCfg.isSet = true default: return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)