Merge pull request #2 from marko1777/awg-2-alpha

Adding lua encode integration
This commit is contained in:
Mark Puha 2025-02-11 10:20:11 +01:00 committed by GitHub
commit 1a1e21e808
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 434 additions and 156 deletions

View file

@ -17,7 +17,7 @@ generate-version-and-build:
@$(MAKE) amneziawg-go
amneziawg-go: $(wildcard *.go) $(wildcard */*.go)
go build -v -o "$@"
go build -tags luajit -v -o "$@"
install: amneziawg-go
@install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/amneziawg-go"

View file

@ -12,6 +12,7 @@ import (
"time"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/device/internal/adapter"
"github.com/amnezia-vpn/amneziawg-go/ipc"
"github.com/amnezia-vpn/amneziawg-go/ratelimiter"
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
@ -92,11 +93,16 @@ type Device struct {
closed chan struct{}
log *Logger
isASecOn abool.AtomicBool
aSecMux sync.RWMutex
aSecCfg aSecCfgType
awg awgType
}
type awgType struct {
isASecOn abool.AtomicBool
mutex sync.RWMutex
aSecCfg aSecCfgType
junkCreator junkCreator
codec *adapter.Lua
}
type aSecCfgType struct {
@ -428,6 +434,10 @@ func (device *Device) Close() {
device.resetProtocol()
if device.awg.codec != nil {
device.awg.codec.Close()
device.awg.codec = nil
}
device.log.Verbosef("Device closed")
close(device.closed)
}
@ -574,54 +584,56 @@ func (device *Device) BindClose() error {
return err
}
func (device *Device) isAdvancedSecurityOn() bool {
return device.isASecOn.IsSet()
return device.awg.isASecOn.IsSet()
}
func (device *Device) resetProtocol() {
// restore default message type values
MessageInitiationType = 1
MessageResponseType = 2
MessageCookieReplyType = 3
MessageTransportType = 4
MessageInitiationType = DefaultMessageInitiationType
MessageResponseType = DefaultMessageResponseType
MessageCookieReplyType = DefaultMessageCookieReplyType
MessageTransportType = DefaultMessageTransportType
}
func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
if !tempASecCfg.isSet {
return err
func (device *Device) handlePostConfig(tempAwgType *awgType) (err error) {
device.awg.codec = tempAwgType.codec
if !tempAwgType.aSecCfg.isSet {
return nil
}
isASecOn := false
device.aSecMux.Lock()
if tempASecCfg.junkPacketCount < 0 {
device.awg.mutex.Lock()
if tempAwgType.aSecCfg.junkPacketCount < 0 {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketCount should be non negative",
)
}
device.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount
if tempASecCfg.junkPacketCount != 0 {
device.awg.aSecCfg.junkPacketCount = tempAwgType.aSecCfg.junkPacketCount
if tempAwgType.aSecCfg.junkPacketCount != 0 {
isASecOn = true
}
device.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize
if tempASecCfg.junkPacketMinSize != 0 {
device.awg.aSecCfg.junkPacketMinSize = tempAwgType.aSecCfg.junkPacketMinSize
if tempAwgType.aSecCfg.junkPacketMinSize != 0 {
isASecOn = true
}
if device.aSecCfg.junkPacketCount > 0 &&
tempASecCfg.junkPacketMaxSize == tempASecCfg.junkPacketMinSize {
if device.awg.aSecCfg.junkPacketCount > 0 &&
tempAwgType.aSecCfg.junkPacketMaxSize == tempAwgType.aSecCfg.junkPacketMinSize {
tempASecCfg.junkPacketMaxSize++ // to make rand gen work
tempAwgType.aSecCfg.junkPacketMaxSize++ // to make rand gen work
}
if tempASecCfg.junkPacketMaxSize >= MaxSegmentSize {
device.aSecCfg.junkPacketMinSize = 0
device.aSecCfg.junkPacketMaxSize = 1
if tempAwgType.aSecCfg.junkPacketMaxSize >= MaxSegmentSize {
device.awg.aSecCfg.junkPacketMinSize = 0
device.awg.aSecCfg.junkPacketMaxSize = 1
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d; %w",
tempASecCfg.junkPacketMaxSize,
tempAwgType.aSecCfg.junkPacketMaxSize,
MaxSegmentSize,
err,
)
@ -629,41 +641,41 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
tempASecCfg.junkPacketMaxSize,
tempAwgType.aSecCfg.junkPacketMaxSize,
MaxSegmentSize,
)
}
} else if tempASecCfg.junkPacketMaxSize < tempASecCfg.junkPacketMinSize {
} else if tempAwgType.aSecCfg.junkPacketMaxSize < tempAwgType.aSecCfg.junkPacketMinSize {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"maxSize: %d; should be greater than minSize: %d; %w",
tempASecCfg.junkPacketMaxSize,
tempASecCfg.junkPacketMinSize,
tempAwgType.aSecCfg.junkPacketMaxSize,
tempAwgType.aSecCfg.junkPacketMinSize,
err,
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"maxSize: %d; should be greater than minSize: %d",
tempASecCfg.junkPacketMaxSize,
tempASecCfg.junkPacketMinSize,
tempAwgType.aSecCfg.junkPacketMaxSize,
tempAwgType.aSecCfg.junkPacketMinSize,
)
}
} else {
device.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize
device.awg.aSecCfg.junkPacketMaxSize = tempAwgType.aSecCfg.junkPacketMaxSize
}
if tempASecCfg.junkPacketMaxSize != 0 {
if tempAwgType.aSecCfg.junkPacketMaxSize != 0 {
isASecOn = true
}
if MessageInitiationSize+tempASecCfg.initPacketJunkSize >= MaxSegmentSize {
if MessageInitiationSize+tempAwgType.aSecCfg.initPacketJunkSize >= MaxSegmentSize {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
tempASecCfg.initPacketJunkSize,
tempAwgType.aSecCfg.initPacketJunkSize,
MaxSegmentSize,
err,
)
@ -671,24 +683,24 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempASecCfg.initPacketJunkSize,
tempAwgType.aSecCfg.initPacketJunkSize,
MaxSegmentSize,
)
}
} else {
device.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize
device.awg.aSecCfg.initPacketJunkSize = tempAwgType.aSecCfg.initPacketJunkSize
}
if tempASecCfg.initPacketJunkSize != 0 {
if tempAwgType.aSecCfg.initPacketJunkSize != 0 {
isASecOn = true
}
if MessageResponseSize+tempASecCfg.responsePacketJunkSize >= MaxSegmentSize {
if MessageResponseSize+tempAwgType.aSecCfg.responsePacketJunkSize >= MaxSegmentSize {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
tempASecCfg.responsePacketJunkSize,
tempAwgType.aSecCfg.responsePacketJunkSize,
MaxSegmentSize,
err,
)
@ -696,56 +708,56 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempASecCfg.responsePacketJunkSize,
tempAwgType.aSecCfg.responsePacketJunkSize,
MaxSegmentSize,
)
}
} else {
device.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize
device.awg.aSecCfg.responsePacketJunkSize = tempAwgType.aSecCfg.responsePacketJunkSize
}
if tempASecCfg.responsePacketJunkSize != 0 {
if tempAwgType.aSecCfg.responsePacketJunkSize != 0 {
isASecOn = true
}
if tempASecCfg.initPacketMagicHeader > 4 {
if tempAwgType.aSecCfg.initPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating init_packet_magic_header")
device.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader
MessageInitiationType = device.aSecCfg.initPacketMagicHeader
device.awg.aSecCfg.initPacketMagicHeader = tempAwgType.aSecCfg.initPacketMagicHeader
MessageInitiationType = device.awg.aSecCfg.initPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default init type")
MessageInitiationType = 1
MessageInitiationType = DefaultMessageInitiationType
}
if tempASecCfg.responsePacketMagicHeader > 4 {
if tempAwgType.aSecCfg.responsePacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
device.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader
MessageResponseType = device.aSecCfg.responsePacketMagicHeader
device.awg.aSecCfg.responsePacketMagicHeader = tempAwgType.aSecCfg.responsePacketMagicHeader
MessageResponseType = device.awg.aSecCfg.responsePacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default response type")
MessageResponseType = 2
MessageResponseType = DefaultMessageResponseType
}
if tempASecCfg.underloadPacketMagicHeader > 4 {
if tempAwgType.aSecCfg.underloadPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
device.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader
MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader
device.awg.aSecCfg.underloadPacketMagicHeader = tempAwgType.aSecCfg.underloadPacketMagicHeader
MessageCookieReplyType = device.awg.aSecCfg.underloadPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default underload type")
MessageCookieReplyType = 3
MessageCookieReplyType = DefaultMessageCookieReplyType
}
if tempASecCfg.transportPacketMagicHeader > 4 {
if tempAwgType.aSecCfg.transportPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
device.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader
MessageTransportType = device.aSecCfg.transportPacketMagicHeader
device.awg.aSecCfg.transportPacketMagicHeader = tempAwgType.aSecCfg.transportPacketMagicHeader
MessageTransportType = device.awg.aSecCfg.transportPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default transport type")
MessageTransportType = 4
MessageTransportType = DefaultMessageTransportType
}
isSameMap := map[uint32]bool{}
@ -778,8 +790,8 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
}
}
newInitSize := MessageInitiationSize + device.aSecCfg.initPacketJunkSize
newResponseSize := MessageResponseSize + device.aSecCfg.responsePacketJunkSize
newInitSize := MessageInitiationSize + device.awg.aSecCfg.initPacketJunkSize
newResponseSize := MessageResponseSize + device.awg.aSecCfg.responsePacketJunkSize
if newInitSize == newResponseSize {
if err != nil {
@ -807,16 +819,34 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
}
msgTypeToJunkSize = map[uint32]int{
MessageInitiationType: device.aSecCfg.initPacketJunkSize,
MessageResponseType: device.aSecCfg.responsePacketJunkSize,
MessageInitiationType: device.awg.aSecCfg.initPacketJunkSize,
MessageResponseType: device.awg.aSecCfg.responsePacketJunkSize,
MessageCookieReplyType: 0,
MessageTransportType: 0,
}
}
device.isASecOn.SetTo(isASecOn)
device.junkCreator, err = NewJunkCreator(device)
device.aSecMux.Unlock()
device.awg.isASecOn.SetTo(isASecOn)
if device.awg.isASecOn.IsSet() {
device.awg.junkCreator, err = NewJunkCreator(device)
}
device.awg.mutex.Unlock()
return err
}
func (device *Device) isCodecActive() bool {
return device.awg.codec != nil
}
func (device *Device) codecPacketIfActive(msgType uint32, packet []byte) ([]byte, error) {
if device.isCodecActive() {
var err error
packet, err = device.awg.codec.Generate(int64(msgType),packet)
if err != nil {
device.log.Errorf("%v - Failed to run codec generate: %v", device, err)
return nil, err
}
}
return packet, nil
}

View file

@ -103,10 +103,22 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
}
pub1, pub2 := key1.publicKey(), key2.publicKey()
/*
head = "headheadhead"
tail = "tailtailtail"
function d_gen(msg_type, data, counter)
return head .. data .. tail
end
function d_parse(data)
return string.match(data, head.. "(.-)".. tail)
end
*/
cfgs[0] = uapiCfg(
"private_key", hex.EncodeToString(key1[:]),
"listen_port", "0",
"replace_peers", "true",
"lua_codec", "aGVhZCA9ICJoZWFkaGVhZGhlYWQiCnRhaWwgPSAidGFpbHRhaWx0YWlsIgpmdW5jdGlvbiBkX2dlbihtc2dfdHlwZSwgZGF0YSwgY291bnRlcikKCXJldHVybiBoZWFkIC4uIGRhdGEgLi4gdGFpbAplbmQKCmZ1bmN0aW9uIGRfcGFyc2UoZGF0YSkKICAgICAgICByZXR1cm4gc3RyaW5nLm1hdGNoKGRhdGEsIGhlYWQuLiAiKC4tKSIuLiB0YWlsKQplbmQK",
"jc", "5",
"jmin", "500",
"jmax", "1000",
@ -114,8 +126,8 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
"s2", "40",
"h1", "123456",
"h2", "67543",
"h4", "32345",
"h3", "123123",
"h4", "32345",
"public_key", hex.EncodeToString(pub2[:]),
"protocol_version", "1",
"replace_allowed_ips", "true",
@ -129,6 +141,7 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
"private_key", hex.EncodeToString(key2[:]),
"listen_port", "0",
"replace_peers", "true",
"lua_codec", "aGVhZCA9ICJoZWFkaGVhZGhlYWQiCnRhaWwgPSAidGFpbHRhaWx0YWlsIgpmdW5jdGlvbiBkX2dlbihtc2dfdHlwZSwgZGF0YSwgY291bnRlcikKCXJldHVybiBoZWFkIC4uIGRhdGEgLi4gdGFpbAplbmQKCmZ1bmN0aW9uIGRfcGFyc2UoZGF0YSkKICAgICAgICByZXR1cm4gc3RyaW5nLm1hdGNoKGRhdGEsIGhlYWQuLiAiKC4tKSIuLiB0YWlsKQplbmQK",
"jc", "5",
"jmin", "500",
"jmax", "1000",
@ -136,8 +149,8 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
"s2", "40",
"h1", "123456",
"h2", "67543",
"h4", "32345",
"h3", "123123",
"h4", "32345",
"public_key", hex.EncodeToString(pub1[:]),
"protocol_version", "1",
"replace_allowed_ips", "true",
@ -275,7 +288,7 @@ func TestTwoDevicePing(t *testing.T) {
}
// Run test with -race=false to avoid the race for setting the default msgTypes 2 times
func TestTwoDevicePingASecurity(t *testing.T) {
func TestASecurityTwoDevicePing(t *testing.T) {
goroutineLeakCheck(t)
pair := genTestPair(t, true, true)
t.Run("ping 1.0.0.1", func(t *testing.T) {

View file

@ -0,0 +1,100 @@
package adapter
import (
"encoding/base64"
"fmt"
"sync/atomic"
"github.com/aarzilli/golua/lua"
)
type Lua struct {
generateState *lua.State
parseState *lua.State
packetCounter atomic.Int64
base64LuaCode string
}
type LuaParams struct {
Base64LuaCode string
}
func NewLua(params LuaParams) (*Lua, error) {
luaCode, err := base64.StdEncoding.DecodeString(params.Base64LuaCode)
if err != nil {
return nil, err
}
strLuaCode := string(luaCode)
// fmt.Println(strLuaCode)
generateState, err := initState(strLuaCode)
if err != nil {
return nil, err
}
parseState, err := initState(strLuaCode)
if err != nil {
return nil, err
}
return &Lua{
generateState: generateState,
parseState: parseState,
base64LuaCode: params.Base64LuaCode,
}, nil
}
func initState(luaCode string) (*lua.State, error) {
state := lua.NewState()
state.OpenLibs()
if err := state.DoString(string(luaCode)); err != nil {
return nil, fmt.Errorf("Error loading Lua code: %v\n", err)
}
return state, nil
}
func (l *Lua) Close() {
l.generateState.Close()
l.parseState.Close()
}
// Only thread safe if used by wg packet creation which happens independably
func (l *Lua) Generate(
msgType int64,
data []byte,
) ([]byte, error) {
l.generateState.GetGlobal("d_gen")
l.generateState.PushInteger(msgType)
l.generateState.PushBytes(data)
l.generateState.PushInteger(l.packetCounter.Add(1))
if err := l.generateState.Call(3, 1); err != nil {
return nil, fmt.Errorf("Error calling Lua function: %v\n", err)
}
result := l.generateState.ToBytes(-1)
l.generateState.Pop(1)
return result, nil
}
// Only thread safe if used by wg packet receive which happens independably
func (l *Lua) Parse(data []byte) ([]byte, error) {
l.parseState.GetGlobal("d_parse")
l.parseState.PushBytes(data)
if err := l.parseState.Call(1, 1); err != nil {
return nil, fmt.Errorf("Error calling Lua function: %v\n", err)
}
result := l.parseState.ToBytes(-1)
l.parseState.Pop(1)
return result, nil
}
func (l *Lua) Base64LuaCode() string {
return l.base64LuaCode
}

View file

@ -0,0 +1,82 @@
package adapter
import (
"testing"
)
func newLua() *Lua {
lua, _ := NewLua(LuaParams{
/*
function d_gen(msg_type, data, counter)
local header = "header"
return counter .. header .. data
end
function d_parse(data)
local header = "1header"
return string.sub(data, #header+1)
end
*/
Base64LuaCode: "CmZ1bmN0aW9uIGRfZ2VuKG1zZ190eXBlLCBkYXRhLCBjb3VudGVyKQoJbG9jYWwgaGVhZGVyID0gImhlYWRlciIKCXJldHVybiBjb3VudGVyIC4uIGhlYWRlciAuLiBkYXRhCmVuZAoKZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJbG9jYWwgaGVhZGVyID0gIjFoZWFkZXIiCglyZXR1cm4gc3RyaW5nLnN1YihkYXRhLCAjaGVhZGVyKzEpCmVuZAo=",
})
return lua
}
func TestLua_Generate(t *testing.T) {
t.Run("", func(t *testing.T) {
l := newLua()
defer l.Close()
got, err := l.Generate(1, []byte("test"))
if err != nil {
t.Errorf(
"Lua.Generate() error = %v, wantErr %v",
err,
nil,
)
return
}
want := "1headertest"
if string(got) != want {
t.Errorf("Lua.Generate() = %v, want %v", string(got), want)
}
})
}
func TestLua_Parse(t *testing.T) {
t.Run("", func(t *testing.T) {
l := newLua()
defer l.Close()
got, err := l.Parse([]byte("1headertest"))
if err != nil {
t.Errorf("Lua.Parse() error = %v, wantErr %v", err, nil)
return
}
want := "test"
if string(got) != want {
t.Errorf("Lua.Parse() = %v, want %v", got, want)
}
})
}
var R []byte
var l = newLua()
func BenchmarkLuaGenerate(b *testing.B) {
for i := 0; i < b.N; i++ { var err error
R, err = l.Generate(1, []byte("test"))
if err != nil {
return
}
}
}
func BenchmarkLuaParse(b *testing.B) {
for i := 0; i < b.N; i++ {
var err error
R, err = l.Parse([]byte("1headertest"))
if err != nil {
return
}
}
}

View file

@ -23,12 +23,12 @@ func NewJunkCreator(d *Device) (junkCreator, error) {
// Should be called with aSecMux RLocked
func (jc *junkCreator) createJunkPackets(peer *Peer) ([][]byte, error) {
if jc.device.aSecCfg.junkPacketCount == 0 {
if jc.device.awg.aSecCfg.junkPacketCount == 0 {
return nil, nil
}
junks := make([][]byte, 0, jc.device.aSecCfg.junkPacketCount)
for i := 0; i < jc.device.aSecCfg.junkPacketCount; i++ {
junks := make([][]byte, 0, jc.device.awg.aSecCfg.junkPacketCount)
for i := 0; i < jc.device.awg.aSecCfg.junkPacketCount; i++ {
packetSize := jc.randomPacketSize()
junk, err := jc.randomJunkWithSize(packetSize)
if err != nil {
@ -48,9 +48,9 @@ func (jc *junkCreator) createJunkPackets(peer *Peer) ([][]byte, error) {
func (jc *junkCreator) randomPacketSize() int {
return int(
jc.cha8Rand.Uint64()%uint64(
jc.device.aSecCfg.junkPacketMaxSize-jc.device.aSecCfg.junkPacketMinSize,
jc.device.awg.aSecCfg.junkPacketMaxSize-jc.device.awg.aSecCfg.junkPacketMinSize,
),
) + jc.device.aSecCfg.junkPacketMinSize
) + jc.device.awg.aSecCfg.junkPacketMinSize
}
// Should be called with aSecMux RLocked

View file

@ -91,13 +91,13 @@ func Test_junkCreator_randomPacketSize(t *testing.T) {
}
for range [30]struct{}{} {
t.Run("", func(t *testing.T) {
if got := jc.randomPacketSize(); jc.device.aSecCfg.junkPacketMinSize > got ||
got > jc.device.aSecCfg.junkPacketMaxSize {
if got := jc.randomPacketSize(); jc.device.awg.aSecCfg.junkPacketMinSize > got ||
got > jc.device.awg.aSecCfg.junkPacketMaxSize {
t.Errorf(
"junkCreator.randomPacketSize() = %v, not between range [%v,%v]",
got,
jc.device.aSecCfg.junkPacketMinSize,
jc.device.aSecCfg.junkPacketMaxSize,
jc.device.awg.aSecCfg.junkPacketMinSize,
jc.device.awg.aSecCfg.junkPacketMaxSize,
)
}
})

View file

@ -52,11 +52,18 @@ const (
WGLabelCookie = "cookie--"
)
const (
DefaultMessageInitiationType uint32 = 1
DefaultMessageResponseType uint32 = 2
DefaultMessageCookieReplyType uint32 = 3
DefaultMessageTransportType uint32 = 4
)
var (
MessageInitiationType uint32 = 1
MessageResponseType uint32 = 2
MessageCookieReplyType uint32 = 3
MessageTransportType uint32 = 4
MessageInitiationType uint32 = DefaultMessageInitiationType
MessageResponseType uint32 = DefaultMessageResponseType
MessageCookieReplyType uint32 = DefaultMessageCookieReplyType
MessageTransportType uint32 = DefaultMessageTransportType
)
const (
@ -197,12 +204,12 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(handshake.remoteStatic[:])
device.aSecMux.RLock()
device.awg.mutex.RLock()
msg := MessageInitiation{
Type: MessageInitiationType,
Ephemeral: handshake.localEphemeral.publicKey(),
}
device.aSecMux.RUnlock()
device.awg.mutex.RUnlock()
handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:])
@ -256,12 +263,12 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
chainKey [blake2s.Size]byte
)
device.aSecMux.RLock()
device.awg.mutex.RLock()
if msg.Type != MessageInitiationType {
device.aSecMux.RUnlock()
device.awg.mutex.RUnlock()
return nil
}
device.aSecMux.RUnlock()
device.awg.mutex.RUnlock()
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
@ -376,9 +383,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
}
var msg MessageResponse
device.aSecMux.RLock()
device.awg.mutex.RLock()
msg.Type = MessageResponseType
device.aSecMux.RUnlock()
device.awg.mutex.RUnlock()
msg.Sender = handshake.localIndex
msg.Receiver = handshake.remoteIndex
@ -428,12 +435,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
}
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
device.aSecMux.RLock()
device.awg.mutex.RLock()
if msg.Type != MessageResponseType {
device.aSecMux.RUnlock()
device.awg.mutex.RUnlock()
return nil
}
device.aSecMux.RUnlock()
device.awg.mutex.RUnlock()
// lookup handshake by receiver

View file

@ -129,23 +129,34 @@ func (device *Device) RoutineReceiveIncoming(
}
deathSpiral = 0
device.aSecMux.RLock()
device.awg.mutex.RLock()
// handle each packet in the batch
for i, size := range sizes[:count] {
if size < MinMessageSize {
continue
}
// check size of packet
packet := bufsArrs[i][:size]
if device.isCodecActive() {
realPacket, err := device.awg.codec.Parse(packet)
copy(packet, realPacket)
size = len(realPacket)
packet = bufsArrs[i][:size]
if err != nil {
device.log.Verbosef(
"Couldn't parse message; reason: %v",
err,
)
continue
}
}
var msgType uint32
if device.isAdvancedSecurityOn() {
if assumedMsgType, ok := packetSizeToMsgType[size]; ok {
junkSize := msgTypeToJunkSize[assumedMsgType]
// transport size can align with other header types;
// making sure we have the right msgType
msgType = binary.LittleEndian.Uint32(packet[junkSize : junkSize+4])
msgType = binary.LittleEndian.Uint32(packet[junkSize:junkSize+4])
if msgType == assumedMsgType {
packet = packet[junkSize:]
} else {
@ -245,7 +256,7 @@ func (device *Device) RoutineReceiveIncoming(
default:
}
}
device.aSecMux.RUnlock()
device.awg.mutex.RUnlock()
for peer, elemsContainer := range elemsByPeer {
if peer.isRunning.Load() {
peer.queue.inbound.c <- elemsContainer
@ -285,6 +296,7 @@ func (device *Device) RoutineDecryption(id int) {
content,
nil,
)
if err != nil {
elem.packet = nil
}
@ -304,7 +316,7 @@ func (device *Device) RoutineHandshake(id int) {
for elem := range device.queue.handshake.c {
device.aSecMux.RLock()
device.awg.mutex.RLock()
// handle cookie fields and ratelimiting
@ -456,7 +468,7 @@ func (device *Device) RoutineHandshake(id int) {
peer.SendKeepalive()
}
skip:
device.aSecMux.RUnlock()
device.awg.mutex.RUnlock()
device.PutMessageBuffer(elem.buffer)
}
}

View file

@ -127,13 +127,14 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
)
return err
}
var sendBuffer [][]byte
// so only packet processed for cookie generation
var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() {
peer.device.aSecMux.RLock()
junks, err := peer.device.junkCreator.createJunkPackets(peer)
peer.device.aSecMux.RUnlock()
peer.device.awg.mutex.RLock()
junks, err := peer.device.awg.junkCreator.createJunkPackets(peer)
peer.device.awg.mutex.RUnlock()
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
@ -153,19 +154,19 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
}
}
peer.device.aSecMux.RLock()
if peer.device.aSecCfg.initPacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize)
peer.device.awg.mutex.RLock()
if peer.device.awg.aSecCfg.initPacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.awg.aSecCfg.initPacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize)
err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.initPacketJunkSize)
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
peer.device.aSecMux.RUnlock()
peer.device.awg.mutex.RUnlock()
return err
}
junkedHeader = writer.Bytes()
}
peer.device.aSecMux.RUnlock()
peer.device.awg.mutex.RUnlock()
}
var buf [MessageInitiationSize]byte
@ -175,6 +176,10 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
peer.cookieGenerator.AddMacs(packet)
junkedHeader = append(junkedHeader, packet...)
if junkedHeader, err = peer.device.codecPacketIfActive(DefaultMessageInitiationType, junkedHeader); err != nil {
return err
}
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
@ -211,19 +216,19 @@ func (peer *Peer) SendHandshakeResponse() error {
}
var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() {
peer.device.aSecMux.RLock()
if peer.device.aSecCfg.responsePacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize)
peer.device.awg.mutex.RLock()
if peer.device.awg.aSecCfg.responsePacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.awg.aSecCfg.responsePacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize)
err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.responsePacketJunkSize)
if err != nil {
peer.device.aSecMux.RUnlock()
peer.device.awg.mutex.RUnlock()
peer.device.log.Errorf("%v - %v", peer, err)
return err
}
junkedHeader = writer.Bytes()
}
peer.device.aSecMux.RUnlock()
peer.device.awg.mutex.RUnlock()
}
var buf [MessageResponseSize]byte
writer := bytes.NewBuffer(buf[:0])
@ -233,6 +238,10 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.cookieGenerator.AddMacs(packet)
junkedHeader = append(junkedHeader, packet...)
if junkedHeader, err = peer.device.codecPacketIfActive(DefaultMessageResponseType, junkedHeader); err != nil {
return err
}
err = peer.BeginSymmetricSession()
if err != nil {
peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
@ -277,8 +286,13 @@ func (device *Device) SendHandshakeCookie(
var buf [MessageCookieReplySize]byte
writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, reply)
packet := writer.Bytes()
if packet, err = device.codecPacketIfActive(DefaultMessageCookieReplyType, packet); err != nil {
return err
}
// TODO: allocation could be avoided
device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
device.net.bind.Send([][]byte{packet}, initiatingElem.endpoint)
return nil
}
@ -578,6 +592,10 @@ func (device *Device) RoutineEncryption(id int) {
elem.packet,
nil,
)
var err error
if elem.packet, err = device.codecPacketIfActive(DefaultMessageTransportType, elem.packet); err != nil {
continue
}
}
elemsContainer.Unlock()
}

View file

@ -18,6 +18,7 @@ import (
"sync"
"time"
"github.com/amnezia-vpn/amneziawg-go/device/internal/adapter"
"github.com/amnezia-vpn/amneziawg-go/ipc"
)
@ -97,33 +98,36 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
sendf("fwmark=%d", device.net.fwmark)
}
if device.awg.codec != nil {
sendf("lua_codec=%s", device.awg.codec.Base64LuaCode())
}
if device.isAdvancedSecurityOn() {
if device.aSecCfg.junkPacketCount != 0 {
sendf("jc=%d", device.aSecCfg.junkPacketCount)
if device.awg.aSecCfg.junkPacketCount != 0 {
sendf("jc=%d", device.awg.aSecCfg.junkPacketCount)
}
if device.aSecCfg.junkPacketMinSize != 0 {
sendf("jmin=%d", device.aSecCfg.junkPacketMinSize)
if device.awg.aSecCfg.junkPacketMinSize != 0 {
sendf("jmin=%d", device.awg.aSecCfg.junkPacketMinSize)
}
if device.aSecCfg.junkPacketMaxSize != 0 {
sendf("jmax=%d", device.aSecCfg.junkPacketMaxSize)
if device.awg.aSecCfg.junkPacketMaxSize != 0 {
sendf("jmax=%d", device.awg.aSecCfg.junkPacketMaxSize)
}
if device.aSecCfg.initPacketJunkSize != 0 {
sendf("s1=%d", device.aSecCfg.initPacketJunkSize)
if device.awg.aSecCfg.initPacketJunkSize != 0 {
sendf("s1=%d", device.awg.aSecCfg.initPacketJunkSize)
}
if device.aSecCfg.responsePacketJunkSize != 0 {
sendf("s2=%d", device.aSecCfg.responsePacketJunkSize)
if device.awg.aSecCfg.responsePacketJunkSize != 0 {
sendf("s2=%d", device.awg.aSecCfg.responsePacketJunkSize)
}
if device.aSecCfg.initPacketMagicHeader != 0 {
sendf("h1=%d", device.aSecCfg.initPacketMagicHeader)
if device.awg.aSecCfg.initPacketMagicHeader != 0 {
sendf("h1=%d", device.awg.aSecCfg.initPacketMagicHeader)
}
if device.aSecCfg.responsePacketMagicHeader != 0 {
sendf("h2=%d", device.aSecCfg.responsePacketMagicHeader)
if device.awg.aSecCfg.responsePacketMagicHeader != 0 {
sendf("h2=%d", device.awg.aSecCfg.responsePacketMagicHeader)
}
if device.aSecCfg.underloadPacketMagicHeader != 0 {
sendf("h3=%d", device.aSecCfg.underloadPacketMagicHeader)
if device.awg.aSecCfg.underloadPacketMagicHeader != 0 {
sendf("h3=%d", device.awg.aSecCfg.underloadPacketMagicHeader)
}
if device.aSecCfg.transportPacketMagicHeader != 0 {
sendf("h4=%d", device.aSecCfg.transportPacketMagicHeader)
if device.awg.aSecCfg.transportPacketMagicHeader != 0 {
sendf("h4=%d", device.awg.aSecCfg.transportPacketMagicHeader)
}
}
@ -180,13 +184,13 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
peer := new(ipcSetPeer)
deviceConfig := true
tempASecCfg := aSecCfgType{}
tempAwgTpe := awgType{}
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
if line == "" {
// Blank line means terminate operation.
err := device.handlePostConfig(&tempASecCfg)
err := device.handlePostConfig(&tempAwgTpe)
if err != nil {
return err
}
@ -217,7 +221,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
var err error
if deviceConfig {
err = device.handleDeviceLine(key, value, &tempASecCfg)
err = device.handleDeviceLine(key, value, &tempAwgTpe)
} else {
err = device.handlePeerLine(peer, key, value)
}
@ -225,7 +229,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
return err
}
}
err = device.handlePostConfig(&tempASecCfg)
err = device.handlePostConfig(&tempAwgTpe)
if err != nil {
return err
}
@ -237,7 +241,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
return nil
}
func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgType) error {
func (device *Device) handleDeviceLine(key, value string, tempAwgTpe *awgType) error {
switch key {
case "private_key":
var sk NoisePrivateKey
@ -283,14 +287,23 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
device.log.Verbosef("UAPI: Removing all peers")
device.RemoveAllPeers()
case "lua_codec":
device.log.Verbosef("UAPI: Updating lua_codec")
var err error
tempAwgTpe.codec, err = adapter.NewLua(adapter.LuaParams{
Base64LuaCode: value,
})
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "invalid lua_codec: %w", err)
}
case "jc":
junkPacketCount, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_count %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_count")
tempASecCfg.junkPacketCount = junkPacketCount
tempASecCfg.isSet = true
tempAwgTpe.aSecCfg.junkPacketCount = junkPacketCount
tempAwgTpe.aSecCfg.isSet = true
case "jmin":
junkPacketMinSize, err := strconv.Atoi(value)
@ -298,8 +311,8 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_min_size %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_min_size")
tempASecCfg.junkPacketMinSize = junkPacketMinSize
tempASecCfg.isSet = true
tempAwgTpe.aSecCfg.junkPacketMinSize = junkPacketMinSize
tempAwgTpe.aSecCfg.isSet = true
case "jmax":
junkPacketMaxSize, err := strconv.Atoi(value)
@ -307,8 +320,8 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_max_size %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_max_size")
tempASecCfg.junkPacketMaxSize = junkPacketMaxSize
tempASecCfg.isSet = true
tempAwgTpe.aSecCfg.junkPacketMaxSize = junkPacketMaxSize
tempAwgTpe.aSecCfg.isSet = true
case "s1":
initPacketJunkSize, err := strconv.Atoi(value)
@ -316,8 +329,8 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating init_packet_junk_size")
tempASecCfg.initPacketJunkSize = initPacketJunkSize
tempASecCfg.isSet = true
tempAwgTpe.aSecCfg.initPacketJunkSize = initPacketJunkSize
tempAwgTpe.aSecCfg.isSet = true
case "s2":
responsePacketJunkSize, err := strconv.Atoi(value)
@ -325,40 +338,40 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating response_packet_junk_size")
tempASecCfg.responsePacketJunkSize = responsePacketJunkSize
tempASecCfg.isSet = true
tempAwgTpe.aSecCfg.responsePacketJunkSize = responsePacketJunkSize
tempAwgTpe.aSecCfg.isSet = true
case "h1":
initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_magic_header %w", err)
}
tempASecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader)
tempASecCfg.isSet = true
tempAwgTpe.aSecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader)
tempAwgTpe.aSecCfg.isSet = true
case "h2":
responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_magic_header %w", err)
}
tempASecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader)
tempASecCfg.isSet = true
tempAwgTpe.aSecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader)
tempAwgTpe.aSecCfg.isSet = true
case "h3":
underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse underload_packet_magic_header %w", err)
}
tempASecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
tempASecCfg.isSet = true
tempAwgTpe.aSecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
tempAwgTpe.aSecCfg.isSet = true
case "h4":
transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse transport_packet_magic_header %w", err)
}
tempASecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader)
tempASecCfg.isSet = true
tempAwgTpe.aSecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader)
tempAwgTpe.aSecCfg.isSet = true
default:
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)

1
go.mod
View file

@ -3,6 +3,7 @@ module github.com/amnezia-vpn/amneziawg-go
go 1.23
require (
github.com/aarzilli/golua v0.0.0-20241229084300-cd31ab23902e
github.com/tevino/abool/v2 v2.1.0
golang.org/x/crypto v0.21.0
golang.org/x/net v0.21.0

2
go.sum
View file

@ -1,3 +1,5 @@
github.com/aarzilli/golua v0.0.0-20241229084300-cd31ab23902e h1:ibMKBskN7uMCz9TJgfaIVYVdPyckXm0UjFDRSNV7XB0=
github.com/aarzilli/golua v0.0.0-20241229084300-cd31ab23902e/go.mod h1:hMjfaJVSqVnxenMlsxrq3Ni+vrm9Hs64tU4M7dhUoO4=
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c=