mirror of
https://github.com/amnezia-vpn/amneziawg-go.git
synced 2025-04-16 22:16:55 +02:00
Merge pull request #2 from marko1777/awg-2-alpha
Adding lua encode integration
This commit is contained in:
commit
1a1e21e808
13 changed files with 434 additions and 156 deletions
2
Makefile
2
Makefile
|
@ -17,7 +17,7 @@ generate-version-and-build:
|
|||
@$(MAKE) amneziawg-go
|
||||
|
||||
amneziawg-go: $(wildcard *.go) $(wildcard */*.go)
|
||||
go build -v -o "$@"
|
||||
go build -tags luajit -v -o "$@"
|
||||
|
||||
install: amneziawg-go
|
||||
@install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/amneziawg-go"
|
||||
|
|
160
device/device.go
160
device/device.go
|
@ -12,6 +12,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"github.com/amnezia-vpn/amneziawg-go/device/internal/adapter"
|
||||
"github.com/amnezia-vpn/amneziawg-go/ipc"
|
||||
"github.com/amnezia-vpn/amneziawg-go/ratelimiter"
|
||||
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
|
||||
|
@ -92,11 +93,16 @@ type Device struct {
|
|||
closed chan struct{}
|
||||
log *Logger
|
||||
|
||||
isASecOn abool.AtomicBool
|
||||
aSecMux sync.RWMutex
|
||||
aSecCfg aSecCfgType
|
||||
awg awgType
|
||||
}
|
||||
|
||||
type awgType struct {
|
||||
isASecOn abool.AtomicBool
|
||||
mutex sync.RWMutex
|
||||
aSecCfg aSecCfgType
|
||||
junkCreator junkCreator
|
||||
|
||||
codec *adapter.Lua
|
||||
}
|
||||
|
||||
type aSecCfgType struct {
|
||||
|
@ -428,6 +434,10 @@ func (device *Device) Close() {
|
|||
|
||||
device.resetProtocol()
|
||||
|
||||
if device.awg.codec != nil {
|
||||
device.awg.codec.Close()
|
||||
device.awg.codec = nil
|
||||
}
|
||||
device.log.Verbosef("Device closed")
|
||||
close(device.closed)
|
||||
}
|
||||
|
@ -574,54 +584,56 @@ func (device *Device) BindClose() error {
|
|||
return err
|
||||
}
|
||||
func (device *Device) isAdvancedSecurityOn() bool {
|
||||
return device.isASecOn.IsSet()
|
||||
return device.awg.isASecOn.IsSet()
|
||||
}
|
||||
|
||||
func (device *Device) resetProtocol() {
|
||||
// restore default message type values
|
||||
MessageInitiationType = 1
|
||||
MessageResponseType = 2
|
||||
MessageCookieReplyType = 3
|
||||
MessageTransportType = 4
|
||||
MessageInitiationType = DefaultMessageInitiationType
|
||||
MessageResponseType = DefaultMessageResponseType
|
||||
MessageCookieReplyType = DefaultMessageCookieReplyType
|
||||
MessageTransportType = DefaultMessageTransportType
|
||||
}
|
||||
|
||||
func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
|
||||
if !tempASecCfg.isSet {
|
||||
return err
|
||||
func (device *Device) handlePostConfig(tempAwgType *awgType) (err error) {
|
||||
device.awg.codec = tempAwgType.codec
|
||||
|
||||
if !tempAwgType.aSecCfg.isSet {
|
||||
return nil
|
||||
}
|
||||
|
||||
isASecOn := false
|
||||
device.aSecMux.Lock()
|
||||
if tempASecCfg.junkPacketCount < 0 {
|
||||
device.awg.mutex.Lock()
|
||||
if tempAwgType.aSecCfg.junkPacketCount < 0 {
|
||||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"JunkPacketCount should be non negative",
|
||||
)
|
||||
}
|
||||
device.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount
|
||||
if tempASecCfg.junkPacketCount != 0 {
|
||||
device.awg.aSecCfg.junkPacketCount = tempAwgType.aSecCfg.junkPacketCount
|
||||
if tempAwgType.aSecCfg.junkPacketCount != 0 {
|
||||
isASecOn = true
|
||||
}
|
||||
|
||||
device.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize
|
||||
if tempASecCfg.junkPacketMinSize != 0 {
|
||||
device.awg.aSecCfg.junkPacketMinSize = tempAwgType.aSecCfg.junkPacketMinSize
|
||||
if tempAwgType.aSecCfg.junkPacketMinSize != 0 {
|
||||
isASecOn = true
|
||||
}
|
||||
|
||||
if device.aSecCfg.junkPacketCount > 0 &&
|
||||
tempASecCfg.junkPacketMaxSize == tempASecCfg.junkPacketMinSize {
|
||||
if device.awg.aSecCfg.junkPacketCount > 0 &&
|
||||
tempAwgType.aSecCfg.junkPacketMaxSize == tempAwgType.aSecCfg.junkPacketMinSize {
|
||||
|
||||
tempASecCfg.junkPacketMaxSize++ // to make rand gen work
|
||||
tempAwgType.aSecCfg.junkPacketMaxSize++ // to make rand gen work
|
||||
}
|
||||
|
||||
if tempASecCfg.junkPacketMaxSize >= MaxSegmentSize {
|
||||
device.aSecCfg.junkPacketMinSize = 0
|
||||
device.aSecCfg.junkPacketMaxSize = 1
|
||||
if tempAwgType.aSecCfg.junkPacketMaxSize >= MaxSegmentSize {
|
||||
device.awg.aSecCfg.junkPacketMinSize = 0
|
||||
device.awg.aSecCfg.junkPacketMaxSize = 1
|
||||
if err != nil {
|
||||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d; %w",
|
||||
tempASecCfg.junkPacketMaxSize,
|
||||
tempAwgType.aSecCfg.junkPacketMaxSize,
|
||||
MaxSegmentSize,
|
||||
err,
|
||||
)
|
||||
|
@ -629,41 +641,41 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
|
|||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
|
||||
tempASecCfg.junkPacketMaxSize,
|
||||
tempAwgType.aSecCfg.junkPacketMaxSize,
|
||||
MaxSegmentSize,
|
||||
)
|
||||
}
|
||||
} else if tempASecCfg.junkPacketMaxSize < tempASecCfg.junkPacketMinSize {
|
||||
} else if tempAwgType.aSecCfg.junkPacketMaxSize < tempAwgType.aSecCfg.junkPacketMinSize {
|
||||
if err != nil {
|
||||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"maxSize: %d; should be greater than minSize: %d; %w",
|
||||
tempASecCfg.junkPacketMaxSize,
|
||||
tempASecCfg.junkPacketMinSize,
|
||||
tempAwgType.aSecCfg.junkPacketMaxSize,
|
||||
tempAwgType.aSecCfg.junkPacketMinSize,
|
||||
err,
|
||||
)
|
||||
} else {
|
||||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"maxSize: %d; should be greater than minSize: %d",
|
||||
tempASecCfg.junkPacketMaxSize,
|
||||
tempASecCfg.junkPacketMinSize,
|
||||
tempAwgType.aSecCfg.junkPacketMaxSize,
|
||||
tempAwgType.aSecCfg.junkPacketMinSize,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
device.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize
|
||||
device.awg.aSecCfg.junkPacketMaxSize = tempAwgType.aSecCfg.junkPacketMaxSize
|
||||
}
|
||||
|
||||
if tempASecCfg.junkPacketMaxSize != 0 {
|
||||
if tempAwgType.aSecCfg.junkPacketMaxSize != 0 {
|
||||
isASecOn = true
|
||||
}
|
||||
|
||||
if MessageInitiationSize+tempASecCfg.initPacketJunkSize >= MaxSegmentSize {
|
||||
if MessageInitiationSize+tempAwgType.aSecCfg.initPacketJunkSize >= MaxSegmentSize {
|
||||
if err != nil {
|
||||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
|
||||
tempASecCfg.initPacketJunkSize,
|
||||
tempAwgType.aSecCfg.initPacketJunkSize,
|
||||
MaxSegmentSize,
|
||||
err,
|
||||
)
|
||||
|
@ -671,24 +683,24 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
|
|||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
|
||||
tempASecCfg.initPacketJunkSize,
|
||||
tempAwgType.aSecCfg.initPacketJunkSize,
|
||||
MaxSegmentSize,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
device.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize
|
||||
device.awg.aSecCfg.initPacketJunkSize = tempAwgType.aSecCfg.initPacketJunkSize
|
||||
}
|
||||
|
||||
if tempASecCfg.initPacketJunkSize != 0 {
|
||||
if tempAwgType.aSecCfg.initPacketJunkSize != 0 {
|
||||
isASecOn = true
|
||||
}
|
||||
|
||||
if MessageResponseSize+tempASecCfg.responsePacketJunkSize >= MaxSegmentSize {
|
||||
if MessageResponseSize+tempAwgType.aSecCfg.responsePacketJunkSize >= MaxSegmentSize {
|
||||
if err != nil {
|
||||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
|
||||
tempASecCfg.responsePacketJunkSize,
|
||||
tempAwgType.aSecCfg.responsePacketJunkSize,
|
||||
MaxSegmentSize,
|
||||
err,
|
||||
)
|
||||
|
@ -696,56 +708,56 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
|
|||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
|
||||
tempASecCfg.responsePacketJunkSize,
|
||||
tempAwgType.aSecCfg.responsePacketJunkSize,
|
||||
MaxSegmentSize,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
device.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize
|
||||
device.awg.aSecCfg.responsePacketJunkSize = tempAwgType.aSecCfg.responsePacketJunkSize
|
||||
}
|
||||
|
||||
if tempASecCfg.responsePacketJunkSize != 0 {
|
||||
if tempAwgType.aSecCfg.responsePacketJunkSize != 0 {
|
||||
isASecOn = true
|
||||
}
|
||||
|
||||
if tempASecCfg.initPacketMagicHeader > 4 {
|
||||
if tempAwgType.aSecCfg.initPacketMagicHeader > 4 {
|
||||
isASecOn = true
|
||||
device.log.Verbosef("UAPI: Updating init_packet_magic_header")
|
||||
device.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader
|
||||
MessageInitiationType = device.aSecCfg.initPacketMagicHeader
|
||||
device.awg.aSecCfg.initPacketMagicHeader = tempAwgType.aSecCfg.initPacketMagicHeader
|
||||
MessageInitiationType = device.awg.aSecCfg.initPacketMagicHeader
|
||||
} else {
|
||||
device.log.Verbosef("UAPI: Using default init type")
|
||||
MessageInitiationType = 1
|
||||
MessageInitiationType = DefaultMessageInitiationType
|
||||
}
|
||||
|
||||
if tempASecCfg.responsePacketMagicHeader > 4 {
|
||||
if tempAwgType.aSecCfg.responsePacketMagicHeader > 4 {
|
||||
isASecOn = true
|
||||
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
|
||||
device.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader
|
||||
MessageResponseType = device.aSecCfg.responsePacketMagicHeader
|
||||
device.awg.aSecCfg.responsePacketMagicHeader = tempAwgType.aSecCfg.responsePacketMagicHeader
|
||||
MessageResponseType = device.awg.aSecCfg.responsePacketMagicHeader
|
||||
} else {
|
||||
device.log.Verbosef("UAPI: Using default response type")
|
||||
MessageResponseType = 2
|
||||
MessageResponseType = DefaultMessageResponseType
|
||||
}
|
||||
|
||||
if tempASecCfg.underloadPacketMagicHeader > 4 {
|
||||
if tempAwgType.aSecCfg.underloadPacketMagicHeader > 4 {
|
||||
isASecOn = true
|
||||
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
|
||||
device.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader
|
||||
MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader
|
||||
device.awg.aSecCfg.underloadPacketMagicHeader = tempAwgType.aSecCfg.underloadPacketMagicHeader
|
||||
MessageCookieReplyType = device.awg.aSecCfg.underloadPacketMagicHeader
|
||||
} else {
|
||||
device.log.Verbosef("UAPI: Using default underload type")
|
||||
MessageCookieReplyType = 3
|
||||
MessageCookieReplyType = DefaultMessageCookieReplyType
|
||||
}
|
||||
|
||||
if tempASecCfg.transportPacketMagicHeader > 4 {
|
||||
if tempAwgType.aSecCfg.transportPacketMagicHeader > 4 {
|
||||
isASecOn = true
|
||||
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
|
||||
device.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader
|
||||
MessageTransportType = device.aSecCfg.transportPacketMagicHeader
|
||||
device.awg.aSecCfg.transportPacketMagicHeader = tempAwgType.aSecCfg.transportPacketMagicHeader
|
||||
MessageTransportType = device.awg.aSecCfg.transportPacketMagicHeader
|
||||
} else {
|
||||
device.log.Verbosef("UAPI: Using default transport type")
|
||||
MessageTransportType = 4
|
||||
MessageTransportType = DefaultMessageTransportType
|
||||
}
|
||||
|
||||
isSameMap := map[uint32]bool{}
|
||||
|
@ -778,8 +790,8 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
|
|||
}
|
||||
}
|
||||
|
||||
newInitSize := MessageInitiationSize + device.aSecCfg.initPacketJunkSize
|
||||
newResponseSize := MessageResponseSize + device.aSecCfg.responsePacketJunkSize
|
||||
newInitSize := MessageInitiationSize + device.awg.aSecCfg.initPacketJunkSize
|
||||
newResponseSize := MessageResponseSize + device.awg.aSecCfg.responsePacketJunkSize
|
||||
|
||||
if newInitSize == newResponseSize {
|
||||
if err != nil {
|
||||
|
@ -807,16 +819,34 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
|
|||
}
|
||||
|
||||
msgTypeToJunkSize = map[uint32]int{
|
||||
MessageInitiationType: device.aSecCfg.initPacketJunkSize,
|
||||
MessageResponseType: device.aSecCfg.responsePacketJunkSize,
|
||||
MessageInitiationType: device.awg.aSecCfg.initPacketJunkSize,
|
||||
MessageResponseType: device.awg.aSecCfg.responsePacketJunkSize,
|
||||
MessageCookieReplyType: 0,
|
||||
MessageTransportType: 0,
|
||||
}
|
||||
}
|
||||
|
||||
device.isASecOn.SetTo(isASecOn)
|
||||
device.junkCreator, err = NewJunkCreator(device)
|
||||
device.aSecMux.Unlock()
|
||||
device.awg.isASecOn.SetTo(isASecOn)
|
||||
if device.awg.isASecOn.IsSet() {
|
||||
device.awg.junkCreator, err = NewJunkCreator(device)
|
||||
}
|
||||
device.awg.mutex.Unlock()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (device *Device) isCodecActive() bool {
|
||||
return device.awg.codec != nil
|
||||
}
|
||||
|
||||
func (device *Device) codecPacketIfActive(msgType uint32, packet []byte) ([]byte, error) {
|
||||
if device.isCodecActive() {
|
||||
var err error
|
||||
packet, err = device.awg.codec.Generate(int64(msgType),packet)
|
||||
if err != nil {
|
||||
device.log.Errorf("%v - Failed to run codec generate: %v", device, err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return packet, nil
|
||||
}
|
||||
|
|
|
@ -103,10 +103,22 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
|
|||
}
|
||||
pub1, pub2 := key1.publicKey(), key2.publicKey()
|
||||
|
||||
/*
|
||||
head = "headheadhead"
|
||||
tail = "tailtailtail"
|
||||
function d_gen(msg_type, data, counter)
|
||||
return head .. data .. tail
|
||||
end
|
||||
|
||||
function d_parse(data)
|
||||
return string.match(data, head.. "(.-)".. tail)
|
||||
end
|
||||
*/
|
||||
cfgs[0] = uapiCfg(
|
||||
"private_key", hex.EncodeToString(key1[:]),
|
||||
"listen_port", "0",
|
||||
"replace_peers", "true",
|
||||
"lua_codec", "aGVhZCA9ICJoZWFkaGVhZGhlYWQiCnRhaWwgPSAidGFpbHRhaWx0YWlsIgpmdW5jdGlvbiBkX2dlbihtc2dfdHlwZSwgZGF0YSwgY291bnRlcikKCXJldHVybiBoZWFkIC4uIGRhdGEgLi4gdGFpbAplbmQKCmZ1bmN0aW9uIGRfcGFyc2UoZGF0YSkKICAgICAgICByZXR1cm4gc3RyaW5nLm1hdGNoKGRhdGEsIGhlYWQuLiAiKC4tKSIuLiB0YWlsKQplbmQK",
|
||||
"jc", "5",
|
||||
"jmin", "500",
|
||||
"jmax", "1000",
|
||||
|
@ -114,8 +126,8 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
|
|||
"s2", "40",
|
||||
"h1", "123456",
|
||||
"h2", "67543",
|
||||
"h4", "32345",
|
||||
"h3", "123123",
|
||||
"h4", "32345",
|
||||
"public_key", hex.EncodeToString(pub2[:]),
|
||||
"protocol_version", "1",
|
||||
"replace_allowed_ips", "true",
|
||||
|
@ -129,6 +141,7 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
|
|||
"private_key", hex.EncodeToString(key2[:]),
|
||||
"listen_port", "0",
|
||||
"replace_peers", "true",
|
||||
"lua_codec", "aGVhZCA9ICJoZWFkaGVhZGhlYWQiCnRhaWwgPSAidGFpbHRhaWx0YWlsIgpmdW5jdGlvbiBkX2dlbihtc2dfdHlwZSwgZGF0YSwgY291bnRlcikKCXJldHVybiBoZWFkIC4uIGRhdGEgLi4gdGFpbAplbmQKCmZ1bmN0aW9uIGRfcGFyc2UoZGF0YSkKICAgICAgICByZXR1cm4gc3RyaW5nLm1hdGNoKGRhdGEsIGhlYWQuLiAiKC4tKSIuLiB0YWlsKQplbmQK",
|
||||
"jc", "5",
|
||||
"jmin", "500",
|
||||
"jmax", "1000",
|
||||
|
@ -136,8 +149,8 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
|
|||
"s2", "40",
|
||||
"h1", "123456",
|
||||
"h2", "67543",
|
||||
"h4", "32345",
|
||||
"h3", "123123",
|
||||
"h4", "32345",
|
||||
"public_key", hex.EncodeToString(pub1[:]),
|
||||
"protocol_version", "1",
|
||||
"replace_allowed_ips", "true",
|
||||
|
@ -275,7 +288,7 @@ func TestTwoDevicePing(t *testing.T) {
|
|||
}
|
||||
|
||||
// Run test with -race=false to avoid the race for setting the default msgTypes 2 times
|
||||
func TestTwoDevicePingASecurity(t *testing.T) {
|
||||
func TestASecurityTwoDevicePing(t *testing.T) {
|
||||
goroutineLeakCheck(t)
|
||||
pair := genTestPair(t, true, true)
|
||||
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
||||
|
|
100
device/internal/adapter/lua.go
Normal file
100
device/internal/adapter/lua.go
Normal file
|
@ -0,0 +1,100 @@
|
|||
package adapter
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/aarzilli/golua/lua"
|
||||
)
|
||||
|
||||
type Lua struct {
|
||||
generateState *lua.State
|
||||
parseState *lua.State
|
||||
packetCounter atomic.Int64
|
||||
base64LuaCode string
|
||||
}
|
||||
|
||||
type LuaParams struct {
|
||||
Base64LuaCode string
|
||||
}
|
||||
|
||||
func NewLua(params LuaParams) (*Lua, error) {
|
||||
luaCode, err := base64.StdEncoding.DecodeString(params.Base64LuaCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
strLuaCode := string(luaCode)
|
||||
// fmt.Println(strLuaCode)
|
||||
|
||||
generateState, err := initState(strLuaCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parseState, err := initState(strLuaCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Lua{
|
||||
generateState: generateState,
|
||||
parseState: parseState,
|
||||
base64LuaCode: params.Base64LuaCode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func initState(luaCode string) (*lua.State, error) {
|
||||
state := lua.NewState()
|
||||
state.OpenLibs()
|
||||
|
||||
if err := state.DoString(string(luaCode)); err != nil {
|
||||
return nil, fmt.Errorf("Error loading Lua code: %v\n", err)
|
||||
}
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func (l *Lua) Close() {
|
||||
l.generateState.Close()
|
||||
l.parseState.Close()
|
||||
}
|
||||
|
||||
// Only thread safe if used by wg packet creation which happens independably
|
||||
func (l *Lua) Generate(
|
||||
msgType int64,
|
||||
data []byte,
|
||||
) ([]byte, error) {
|
||||
l.generateState.GetGlobal("d_gen")
|
||||
|
||||
l.generateState.PushInteger(msgType)
|
||||
l.generateState.PushBytes(data)
|
||||
l.generateState.PushInteger(l.packetCounter.Add(1))
|
||||
|
||||
if err := l.generateState.Call(3, 1); err != nil {
|
||||
return nil, fmt.Errorf("Error calling Lua function: %v\n", err)
|
||||
}
|
||||
|
||||
result := l.generateState.ToBytes(-1)
|
||||
l.generateState.Pop(1)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Only thread safe if used by wg packet receive which happens independably
|
||||
func (l *Lua) Parse(data []byte) ([]byte, error) {
|
||||
l.parseState.GetGlobal("d_parse")
|
||||
|
||||
l.parseState.PushBytes(data)
|
||||
if err := l.parseState.Call(1, 1); err != nil {
|
||||
return nil, fmt.Errorf("Error calling Lua function: %v\n", err)
|
||||
}
|
||||
|
||||
result := l.parseState.ToBytes(-1)
|
||||
l.parseState.Pop(1)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (l *Lua) Base64LuaCode() string {
|
||||
return l.base64LuaCode
|
||||
}
|
82
device/internal/adapter/lua_test.go
Normal file
82
device/internal/adapter/lua_test.go
Normal file
|
@ -0,0 +1,82 @@
|
|||
package adapter
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func newLua() *Lua {
|
||||
lua, _ := NewLua(LuaParams{
|
||||
/*
|
||||
function d_gen(msg_type, data, counter)
|
||||
local header = "header"
|
||||
return counter .. header .. data
|
||||
end
|
||||
|
||||
function d_parse(data)
|
||||
local header = "1header"
|
||||
return string.sub(data, #header+1)
|
||||
end
|
||||
*/
|
||||
Base64LuaCode: "CmZ1bmN0aW9uIGRfZ2VuKG1zZ190eXBlLCBkYXRhLCBjb3VudGVyKQoJbG9jYWwgaGVhZGVyID0gImhlYWRlciIKCXJldHVybiBjb3VudGVyIC4uIGhlYWRlciAuLiBkYXRhCmVuZAoKZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJbG9jYWwgaGVhZGVyID0gIjFoZWFkZXIiCglyZXR1cm4gc3RyaW5nLnN1YihkYXRhLCAjaGVhZGVyKzEpCmVuZAo=",
|
||||
})
|
||||
return lua
|
||||
}
|
||||
|
||||
func TestLua_Generate(t *testing.T) {
|
||||
t.Run("", func(t *testing.T) {
|
||||
l := newLua()
|
||||
defer l.Close()
|
||||
got, err := l.Generate(1, []byte("test"))
|
||||
if err != nil {
|
||||
t.Errorf(
|
||||
"Lua.Generate() error = %v, wantErr %v",
|
||||
err,
|
||||
nil,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
want := "1headertest"
|
||||
if string(got) != want {
|
||||
t.Errorf("Lua.Generate() = %v, want %v", string(got), want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLua_Parse(t *testing.T) {
|
||||
t.Run("", func(t *testing.T) {
|
||||
l := newLua()
|
||||
defer l.Close()
|
||||
got, err := l.Parse([]byte("1headertest"))
|
||||
if err != nil {
|
||||
t.Errorf("Lua.Parse() error = %v, wantErr %v", err, nil)
|
||||
return
|
||||
}
|
||||
want := "test"
|
||||
if string(got) != want {
|
||||
t.Errorf("Lua.Parse() = %v, want %v", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
var R []byte
|
||||
var l = newLua()
|
||||
func BenchmarkLuaGenerate(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ { var err error
|
||||
R, err = l.Generate(1, []byte("test"))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLuaParse(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
var err error
|
||||
R, err = l.Parse([]byte("1headertest"))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
|
@ -23,12 +23,12 @@ func NewJunkCreator(d *Device) (junkCreator, error) {
|
|||
|
||||
// Should be called with aSecMux RLocked
|
||||
func (jc *junkCreator) createJunkPackets(peer *Peer) ([][]byte, error) {
|
||||
if jc.device.aSecCfg.junkPacketCount == 0 {
|
||||
if jc.device.awg.aSecCfg.junkPacketCount == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
junks := make([][]byte, 0, jc.device.aSecCfg.junkPacketCount)
|
||||
for i := 0; i < jc.device.aSecCfg.junkPacketCount; i++ {
|
||||
junks := make([][]byte, 0, jc.device.awg.aSecCfg.junkPacketCount)
|
||||
for i := 0; i < jc.device.awg.aSecCfg.junkPacketCount; i++ {
|
||||
packetSize := jc.randomPacketSize()
|
||||
junk, err := jc.randomJunkWithSize(packetSize)
|
||||
if err != nil {
|
||||
|
@ -48,9 +48,9 @@ func (jc *junkCreator) createJunkPackets(peer *Peer) ([][]byte, error) {
|
|||
func (jc *junkCreator) randomPacketSize() int {
|
||||
return int(
|
||||
jc.cha8Rand.Uint64()%uint64(
|
||||
jc.device.aSecCfg.junkPacketMaxSize-jc.device.aSecCfg.junkPacketMinSize,
|
||||
jc.device.awg.aSecCfg.junkPacketMaxSize-jc.device.awg.aSecCfg.junkPacketMinSize,
|
||||
),
|
||||
) + jc.device.aSecCfg.junkPacketMinSize
|
||||
) + jc.device.awg.aSecCfg.junkPacketMinSize
|
||||
}
|
||||
|
||||
// Should be called with aSecMux RLocked
|
||||
|
|
|
@ -91,13 +91,13 @@ func Test_junkCreator_randomPacketSize(t *testing.T) {
|
|||
}
|
||||
for range [30]struct{}{} {
|
||||
t.Run("", func(t *testing.T) {
|
||||
if got := jc.randomPacketSize(); jc.device.aSecCfg.junkPacketMinSize > got ||
|
||||
got > jc.device.aSecCfg.junkPacketMaxSize {
|
||||
if got := jc.randomPacketSize(); jc.device.awg.aSecCfg.junkPacketMinSize > got ||
|
||||
got > jc.device.awg.aSecCfg.junkPacketMaxSize {
|
||||
t.Errorf(
|
||||
"junkCreator.randomPacketSize() = %v, not between range [%v,%v]",
|
||||
got,
|
||||
jc.device.aSecCfg.junkPacketMinSize,
|
||||
jc.device.aSecCfg.junkPacketMaxSize,
|
||||
jc.device.awg.aSecCfg.junkPacketMinSize,
|
||||
jc.device.awg.aSecCfg.junkPacketMaxSize,
|
||||
)
|
||||
}
|
||||
})
|
||||
|
|
|
@ -52,11 +52,18 @@ const (
|
|||
WGLabelCookie = "cookie--"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultMessageInitiationType uint32 = 1
|
||||
DefaultMessageResponseType uint32 = 2
|
||||
DefaultMessageCookieReplyType uint32 = 3
|
||||
DefaultMessageTransportType uint32 = 4
|
||||
)
|
||||
|
||||
var (
|
||||
MessageInitiationType uint32 = 1
|
||||
MessageResponseType uint32 = 2
|
||||
MessageCookieReplyType uint32 = 3
|
||||
MessageTransportType uint32 = 4
|
||||
MessageInitiationType uint32 = DefaultMessageInitiationType
|
||||
MessageResponseType uint32 = DefaultMessageResponseType
|
||||
MessageCookieReplyType uint32 = DefaultMessageCookieReplyType
|
||||
MessageTransportType uint32 = DefaultMessageTransportType
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -197,12 +204,12 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
|
|||
|
||||
handshake.mixHash(handshake.remoteStatic[:])
|
||||
|
||||
device.aSecMux.RLock()
|
||||
device.awg.mutex.RLock()
|
||||
msg := MessageInitiation{
|
||||
Type: MessageInitiationType,
|
||||
Ephemeral: handshake.localEphemeral.publicKey(),
|
||||
}
|
||||
device.aSecMux.RUnlock()
|
||||
device.awg.mutex.RUnlock()
|
||||
|
||||
handshake.mixKey(msg.Ephemeral[:])
|
||||
handshake.mixHash(msg.Ephemeral[:])
|
||||
|
@ -256,12 +263,12 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
|||
chainKey [blake2s.Size]byte
|
||||
)
|
||||
|
||||
device.aSecMux.RLock()
|
||||
device.awg.mutex.RLock()
|
||||
if msg.Type != MessageInitiationType {
|
||||
device.aSecMux.RUnlock()
|
||||
device.awg.mutex.RUnlock()
|
||||
return nil
|
||||
}
|
||||
device.aSecMux.RUnlock()
|
||||
device.awg.mutex.RUnlock()
|
||||
|
||||
device.staticIdentity.RLock()
|
||||
defer device.staticIdentity.RUnlock()
|
||||
|
@ -376,9 +383,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
|||
}
|
||||
|
||||
var msg MessageResponse
|
||||
device.aSecMux.RLock()
|
||||
device.awg.mutex.RLock()
|
||||
msg.Type = MessageResponseType
|
||||
device.aSecMux.RUnlock()
|
||||
device.awg.mutex.RUnlock()
|
||||
msg.Sender = handshake.localIndex
|
||||
msg.Receiver = handshake.remoteIndex
|
||||
|
||||
|
@ -428,12 +435,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
|||
}
|
||||
|
||||
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
||||
device.aSecMux.RLock()
|
||||
device.awg.mutex.RLock()
|
||||
if msg.Type != MessageResponseType {
|
||||
device.aSecMux.RUnlock()
|
||||
device.awg.mutex.RUnlock()
|
||||
return nil
|
||||
}
|
||||
device.aSecMux.RUnlock()
|
||||
device.awg.mutex.RUnlock()
|
||||
|
||||
// lookup handshake by receiver
|
||||
|
||||
|
|
|
@ -129,23 +129,34 @@ func (device *Device) RoutineReceiveIncoming(
|
|||
}
|
||||
deathSpiral = 0
|
||||
|
||||
device.aSecMux.RLock()
|
||||
device.awg.mutex.RLock()
|
||||
// handle each packet in the batch
|
||||
for i, size := range sizes[:count] {
|
||||
if size < MinMessageSize {
|
||||
continue
|
||||
}
|
||||
|
||||
// check size of packet
|
||||
|
||||
packet := bufsArrs[i][:size]
|
||||
if device.isCodecActive() {
|
||||
realPacket, err := device.awg.codec.Parse(packet)
|
||||
copy(packet, realPacket)
|
||||
size = len(realPacket)
|
||||
packet = bufsArrs[i][:size]
|
||||
if err != nil {
|
||||
device.log.Verbosef(
|
||||
"Couldn't parse message; reason: %v",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
}
|
||||
var msgType uint32
|
||||
if device.isAdvancedSecurityOn() {
|
||||
if assumedMsgType, ok := packetSizeToMsgType[size]; ok {
|
||||
junkSize := msgTypeToJunkSize[assumedMsgType]
|
||||
// transport size can align with other header types;
|
||||
// making sure we have the right msgType
|
||||
msgType = binary.LittleEndian.Uint32(packet[junkSize : junkSize+4])
|
||||
msgType = binary.LittleEndian.Uint32(packet[junkSize:junkSize+4])
|
||||
if msgType == assumedMsgType {
|
||||
packet = packet[junkSize:]
|
||||
} else {
|
||||
|
@ -245,7 +256,7 @@ func (device *Device) RoutineReceiveIncoming(
|
|||
default:
|
||||
}
|
||||
}
|
||||
device.aSecMux.RUnlock()
|
||||
device.awg.mutex.RUnlock()
|
||||
for peer, elemsContainer := range elemsByPeer {
|
||||
if peer.isRunning.Load() {
|
||||
peer.queue.inbound.c <- elemsContainer
|
||||
|
@ -285,6 +296,7 @@ func (device *Device) RoutineDecryption(id int) {
|
|||
content,
|
||||
nil,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
elem.packet = nil
|
||||
}
|
||||
|
@ -304,7 +316,7 @@ func (device *Device) RoutineHandshake(id int) {
|
|||
|
||||
for elem := range device.queue.handshake.c {
|
||||
|
||||
device.aSecMux.RLock()
|
||||
device.awg.mutex.RLock()
|
||||
|
||||
// handle cookie fields and ratelimiting
|
||||
|
||||
|
@ -456,7 +468,7 @@ func (device *Device) RoutineHandshake(id int) {
|
|||
peer.SendKeepalive()
|
||||
}
|
||||
skip:
|
||||
device.aSecMux.RUnlock()
|
||||
device.awg.mutex.RUnlock()
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -127,13 +127,14 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
|||
)
|
||||
return err
|
||||
}
|
||||
|
||||
var sendBuffer [][]byte
|
||||
// so only packet processed for cookie generation
|
||||
var junkedHeader []byte
|
||||
if peer.device.isAdvancedSecurityOn() {
|
||||
peer.device.aSecMux.RLock()
|
||||
junks, err := peer.device.junkCreator.createJunkPackets(peer)
|
||||
peer.device.aSecMux.RUnlock()
|
||||
peer.device.awg.mutex.RLock()
|
||||
junks, err := peer.device.awg.junkCreator.createJunkPackets(peer)
|
||||
peer.device.awg.mutex.RUnlock()
|
||||
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("%v - %v", peer, err)
|
||||
|
@ -153,19 +154,19 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
|||
}
|
||||
}
|
||||
|
||||
peer.device.aSecMux.RLock()
|
||||
if peer.device.aSecCfg.initPacketJunkSize != 0 {
|
||||
buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize)
|
||||
peer.device.awg.mutex.RLock()
|
||||
if peer.device.awg.aSecCfg.initPacketJunkSize != 0 {
|
||||
buf := make([]byte, 0, peer.device.awg.aSecCfg.initPacketJunkSize)
|
||||
writer := bytes.NewBuffer(buf[:0])
|
||||
err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize)
|
||||
err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.initPacketJunkSize)
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("%v - %v", peer, err)
|
||||
peer.device.aSecMux.RUnlock()
|
||||
peer.device.awg.mutex.RUnlock()
|
||||
return err
|
||||
}
|
||||
junkedHeader = writer.Bytes()
|
||||
}
|
||||
peer.device.aSecMux.RUnlock()
|
||||
peer.device.awg.mutex.RUnlock()
|
||||
}
|
||||
|
||||
var buf [MessageInitiationSize]byte
|
||||
|
@ -175,6 +176,10 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
|||
peer.cookieGenerator.AddMacs(packet)
|
||||
junkedHeader = append(junkedHeader, packet...)
|
||||
|
||||
if junkedHeader, err = peer.device.codecPacketIfActive(DefaultMessageInitiationType, junkedHeader); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peer.timersAnyAuthenticatedPacketTraversal()
|
||||
peer.timersAnyAuthenticatedPacketSent()
|
||||
|
||||
|
@ -211,19 +216,19 @@ func (peer *Peer) SendHandshakeResponse() error {
|
|||
}
|
||||
var junkedHeader []byte
|
||||
if peer.device.isAdvancedSecurityOn() {
|
||||
peer.device.aSecMux.RLock()
|
||||
if peer.device.aSecCfg.responsePacketJunkSize != 0 {
|
||||
buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize)
|
||||
peer.device.awg.mutex.RLock()
|
||||
if peer.device.awg.aSecCfg.responsePacketJunkSize != 0 {
|
||||
buf := make([]byte, 0, peer.device.awg.aSecCfg.responsePacketJunkSize)
|
||||
writer := bytes.NewBuffer(buf[:0])
|
||||
err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize)
|
||||
err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.responsePacketJunkSize)
|
||||
if err != nil {
|
||||
peer.device.aSecMux.RUnlock()
|
||||
peer.device.awg.mutex.RUnlock()
|
||||
peer.device.log.Errorf("%v - %v", peer, err)
|
||||
return err
|
||||
}
|
||||
junkedHeader = writer.Bytes()
|
||||
}
|
||||
peer.device.aSecMux.RUnlock()
|
||||
peer.device.awg.mutex.RUnlock()
|
||||
}
|
||||
var buf [MessageResponseSize]byte
|
||||
writer := bytes.NewBuffer(buf[:0])
|
||||
|
@ -233,6 +238,10 @@ func (peer *Peer) SendHandshakeResponse() error {
|
|||
peer.cookieGenerator.AddMacs(packet)
|
||||
junkedHeader = append(junkedHeader, packet...)
|
||||
|
||||
if junkedHeader, err = peer.device.codecPacketIfActive(DefaultMessageResponseType, junkedHeader); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = peer.BeginSymmetricSession()
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
|
||||
|
@ -277,8 +286,13 @@ func (device *Device) SendHandshakeCookie(
|
|||
var buf [MessageCookieReplySize]byte
|
||||
writer := bytes.NewBuffer(buf[:0])
|
||||
binary.Write(writer, binary.LittleEndian, reply)
|
||||
packet := writer.Bytes()
|
||||
if packet, err = device.codecPacketIfActive(DefaultMessageCookieReplyType, packet); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: allocation could be avoided
|
||||
device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
|
||||
device.net.bind.Send([][]byte{packet}, initiatingElem.endpoint)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -578,6 +592,10 @@ func (device *Device) RoutineEncryption(id int) {
|
|||
elem.packet,
|
||||
nil,
|
||||
)
|
||||
var err error
|
||||
if elem.packet, err = device.codecPacketIfActive(DefaultMessageTransportType, elem.packet); err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
elemsContainer.Unlock()
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/device/internal/adapter"
|
||||
"github.com/amnezia-vpn/amneziawg-go/ipc"
|
||||
)
|
||||
|
||||
|
@ -97,33 +98,36 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
|
|||
sendf("fwmark=%d", device.net.fwmark)
|
||||
}
|
||||
|
||||
if device.awg.codec != nil {
|
||||
sendf("lua_codec=%s", device.awg.codec.Base64LuaCode())
|
||||
}
|
||||
if device.isAdvancedSecurityOn() {
|
||||
if device.aSecCfg.junkPacketCount != 0 {
|
||||
sendf("jc=%d", device.aSecCfg.junkPacketCount)
|
||||
if device.awg.aSecCfg.junkPacketCount != 0 {
|
||||
sendf("jc=%d", device.awg.aSecCfg.junkPacketCount)
|
||||
}
|
||||
if device.aSecCfg.junkPacketMinSize != 0 {
|
||||
sendf("jmin=%d", device.aSecCfg.junkPacketMinSize)
|
||||
if device.awg.aSecCfg.junkPacketMinSize != 0 {
|
||||
sendf("jmin=%d", device.awg.aSecCfg.junkPacketMinSize)
|
||||
}
|
||||
if device.aSecCfg.junkPacketMaxSize != 0 {
|
||||
sendf("jmax=%d", device.aSecCfg.junkPacketMaxSize)
|
||||
if device.awg.aSecCfg.junkPacketMaxSize != 0 {
|
||||
sendf("jmax=%d", device.awg.aSecCfg.junkPacketMaxSize)
|
||||
}
|
||||
if device.aSecCfg.initPacketJunkSize != 0 {
|
||||
sendf("s1=%d", device.aSecCfg.initPacketJunkSize)
|
||||
if device.awg.aSecCfg.initPacketJunkSize != 0 {
|
||||
sendf("s1=%d", device.awg.aSecCfg.initPacketJunkSize)
|
||||
}
|
||||
if device.aSecCfg.responsePacketJunkSize != 0 {
|
||||
sendf("s2=%d", device.aSecCfg.responsePacketJunkSize)
|
||||
if device.awg.aSecCfg.responsePacketJunkSize != 0 {
|
||||
sendf("s2=%d", device.awg.aSecCfg.responsePacketJunkSize)
|
||||
}
|
||||
if device.aSecCfg.initPacketMagicHeader != 0 {
|
||||
sendf("h1=%d", device.aSecCfg.initPacketMagicHeader)
|
||||
if device.awg.aSecCfg.initPacketMagicHeader != 0 {
|
||||
sendf("h1=%d", device.awg.aSecCfg.initPacketMagicHeader)
|
||||
}
|
||||
if device.aSecCfg.responsePacketMagicHeader != 0 {
|
||||
sendf("h2=%d", device.aSecCfg.responsePacketMagicHeader)
|
||||
if device.awg.aSecCfg.responsePacketMagicHeader != 0 {
|
||||
sendf("h2=%d", device.awg.aSecCfg.responsePacketMagicHeader)
|
||||
}
|
||||
if device.aSecCfg.underloadPacketMagicHeader != 0 {
|
||||
sendf("h3=%d", device.aSecCfg.underloadPacketMagicHeader)
|
||||
if device.awg.aSecCfg.underloadPacketMagicHeader != 0 {
|
||||
sendf("h3=%d", device.awg.aSecCfg.underloadPacketMagicHeader)
|
||||
}
|
||||
if device.aSecCfg.transportPacketMagicHeader != 0 {
|
||||
sendf("h4=%d", device.aSecCfg.transportPacketMagicHeader)
|
||||
if device.awg.aSecCfg.transportPacketMagicHeader != 0 {
|
||||
sendf("h4=%d", device.awg.aSecCfg.transportPacketMagicHeader)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -180,13 +184,13 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
|||
peer := new(ipcSetPeer)
|
||||
deviceConfig := true
|
||||
|
||||
tempASecCfg := aSecCfgType{}
|
||||
tempAwgTpe := awgType{}
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if line == "" {
|
||||
// Blank line means terminate operation.
|
||||
err := device.handlePostConfig(&tempASecCfg)
|
||||
err := device.handlePostConfig(&tempAwgTpe)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -217,7 +221,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
|||
|
||||
var err error
|
||||
if deviceConfig {
|
||||
err = device.handleDeviceLine(key, value, &tempASecCfg)
|
||||
err = device.handleDeviceLine(key, value, &tempAwgTpe)
|
||||
} else {
|
||||
err = device.handlePeerLine(peer, key, value)
|
||||
}
|
||||
|
@ -225,7 +229,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
|||
return err
|
||||
}
|
||||
}
|
||||
err = device.handlePostConfig(&tempASecCfg)
|
||||
err = device.handlePostConfig(&tempAwgTpe)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -237,7 +241,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgType) error {
|
||||
func (device *Device) handleDeviceLine(key, value string, tempAwgTpe *awgType) error {
|
||||
switch key {
|
||||
case "private_key":
|
||||
var sk NoisePrivateKey
|
||||
|
@ -283,14 +287,23 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
|
|||
device.log.Verbosef("UAPI: Removing all peers")
|
||||
device.RemoveAllPeers()
|
||||
|
||||
case "lua_codec":
|
||||
device.log.Verbosef("UAPI: Updating lua_codec")
|
||||
var err error
|
||||
tempAwgTpe.codec, err = adapter.NewLua(adapter.LuaParams{
|
||||
Base64LuaCode: value,
|
||||
})
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "invalid lua_codec: %w", err)
|
||||
}
|
||||
case "jc":
|
||||
junkPacketCount, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_count %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating junk_packet_count")
|
||||
tempASecCfg.junkPacketCount = junkPacketCount
|
||||
tempASecCfg.isSet = true
|
||||
tempAwgTpe.aSecCfg.junkPacketCount = junkPacketCount
|
||||
tempAwgTpe.aSecCfg.isSet = true
|
||||
|
||||
case "jmin":
|
||||
junkPacketMinSize, err := strconv.Atoi(value)
|
||||
|
@ -298,8 +311,8 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
|
|||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_min_size %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating junk_packet_min_size")
|
||||
tempASecCfg.junkPacketMinSize = junkPacketMinSize
|
||||
tempASecCfg.isSet = true
|
||||
tempAwgTpe.aSecCfg.junkPacketMinSize = junkPacketMinSize
|
||||
tempAwgTpe.aSecCfg.isSet = true
|
||||
|
||||
case "jmax":
|
||||
junkPacketMaxSize, err := strconv.Atoi(value)
|
||||
|
@ -307,8 +320,8 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
|
|||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_max_size %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating junk_packet_max_size")
|
||||
tempASecCfg.junkPacketMaxSize = junkPacketMaxSize
|
||||
tempASecCfg.isSet = true
|
||||
tempAwgTpe.aSecCfg.junkPacketMaxSize = junkPacketMaxSize
|
||||
tempAwgTpe.aSecCfg.isSet = true
|
||||
|
||||
case "s1":
|
||||
initPacketJunkSize, err := strconv.Atoi(value)
|
||||
|
@ -316,8 +329,8 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
|
|||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_junk_size %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating init_packet_junk_size")
|
||||
tempASecCfg.initPacketJunkSize = initPacketJunkSize
|
||||
tempASecCfg.isSet = true
|
||||
tempAwgTpe.aSecCfg.initPacketJunkSize = initPacketJunkSize
|
||||
tempAwgTpe.aSecCfg.isSet = true
|
||||
|
||||
case "s2":
|
||||
responsePacketJunkSize, err := strconv.Atoi(value)
|
||||
|
@ -325,40 +338,40 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
|
|||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_junk_size %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating response_packet_junk_size")
|
||||
tempASecCfg.responsePacketJunkSize = responsePacketJunkSize
|
||||
tempASecCfg.isSet = true
|
||||
tempAwgTpe.aSecCfg.responsePacketJunkSize = responsePacketJunkSize
|
||||
tempAwgTpe.aSecCfg.isSet = true
|
||||
|
||||
case "h1":
|
||||
initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_magic_header %w", err)
|
||||
}
|
||||
tempASecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader)
|
||||
tempASecCfg.isSet = true
|
||||
tempAwgTpe.aSecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader)
|
||||
tempAwgTpe.aSecCfg.isSet = true
|
||||
|
||||
case "h2":
|
||||
responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_magic_header %w", err)
|
||||
}
|
||||
tempASecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader)
|
||||
tempASecCfg.isSet = true
|
||||
tempAwgTpe.aSecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader)
|
||||
tempAwgTpe.aSecCfg.isSet = true
|
||||
|
||||
case "h3":
|
||||
underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse underload_packet_magic_header %w", err)
|
||||
}
|
||||
tempASecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
|
||||
tempASecCfg.isSet = true
|
||||
tempAwgTpe.aSecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
|
||||
tempAwgTpe.aSecCfg.isSet = true
|
||||
|
||||
case "h4":
|
||||
transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse transport_packet_magic_header %w", err)
|
||||
}
|
||||
tempASecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader)
|
||||
tempASecCfg.isSet = true
|
||||
tempAwgTpe.aSecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader)
|
||||
tempAwgTpe.aSecCfg.isSet = true
|
||||
|
||||
default:
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
|
||||
|
|
1
go.mod
1
go.mod
|
@ -3,6 +3,7 @@ module github.com/amnezia-vpn/amneziawg-go
|
|||
go 1.23
|
||||
|
||||
require (
|
||||
github.com/aarzilli/golua v0.0.0-20241229084300-cd31ab23902e
|
||||
github.com/tevino/abool/v2 v2.1.0
|
||||
golang.org/x/crypto v0.21.0
|
||||
golang.org/x/net v0.21.0
|
||||
|
|
2
go.sum
2
go.sum
|
@ -1,3 +1,5 @@
|
|||
github.com/aarzilli/golua v0.0.0-20241229084300-cd31ab23902e h1:ibMKBskN7uMCz9TJgfaIVYVdPyckXm0UjFDRSNV7XB0=
|
||||
github.com/aarzilli/golua v0.0.0-20241229084300-cd31ab23902e/go.mod h1:hMjfaJVSqVnxenMlsxrq3Ni+vrm9Hs64tU4M7dhUoO4=
|
||||
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
|
||||
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
|
||||
github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c=
|
||||
|
|
Loading…
Add table
Reference in a new issue