feat: special handshake mechanism

This commit is contained in:
Mark Puha 2025-06-08 16:09:46 +02:00
parent 431b7b1a37
commit a1d8adca48
16 changed files with 532 additions and 218 deletions

View file

@ -12,6 +12,7 @@ import (
"time" "time"
"github.com/amnezia-vpn/amneziawg-go/conn" "github.com/amnezia-vpn/amneziawg-go/conn"
junktag "github.com/amnezia-vpn/amneziawg-go/device/internal/junk-tag"
"github.com/amnezia-vpn/amneziawg-go/ipc" "github.com/amnezia-vpn/amneziawg-go/ipc"
"github.com/amnezia-vpn/amneziawg-go/ratelimiter" "github.com/amnezia-vpn/amneziawg-go/ratelimiter"
"github.com/amnezia-vpn/amneziawg-go/rwcancel" "github.com/amnezia-vpn/amneziawg-go/rwcancel"
@ -19,6 +20,41 @@ import (
"github.com/tevino/abool/v2" "github.com/tevino/abool/v2"
) )
type Version uint8
const (
VersionDefault Version = iota
VersionAwg
VersionAwgSpecialHandshake
)
// TODO:
type AtomicVersion struct {
value atomic.Uint32
}
func NewAtomicVersion(v Version) *AtomicVersion {
av := &AtomicVersion{}
av.Store(v)
return av
}
func (av *AtomicVersion) Load() Version {
return Version(av.value.Load())
}
func (av *AtomicVersion) Store(v Version) {
av.value.Store(uint32(v))
}
func (av *AtomicVersion) CompareAndSwap(old, new Version) bool {
return av.value.CompareAndSwap(uint32(old), uint32(new))
}
func (av *AtomicVersion) Swap(new Version) Version {
return Version(av.value.Swap(uint32(new)))
}
type Device struct { type Device struct {
state struct { state struct {
// state holds the device's state. It is accessed atomically. // state holds the device's state. It is accessed atomically.
@ -92,10 +128,19 @@ type Device struct {
closed chan struct{} closed chan struct{}
log *Logger log *Logger
version Version
awg awg
}
type awg struct {
isASecOn abool.AtomicBool isASecOn abool.AtomicBool
aSecMux sync.RWMutex // TODO: revision the need of the mutex
aSecCfg aSecCfgType aSecMux sync.RWMutex
aSecCfg aSecCfgType
junkCreator junkCreator junkCreator junkCreator
// TODO: determine if it's on
handshakeHandler junktag.SpecialHandshakeHandler
} }
type aSecCfgType struct { type aSecCfgType struct {
@ -558,55 +603,55 @@ func (device *Device) BindClose() error {
return err return err
} }
func (device *Device) isAdvancedSecurityOn() bool { func (device *Device) isAdvancedSecurityOn() bool {
return device.isASecOn.IsSet() return device.awg.isASecOn.IsSet()
} }
func (device *Device) resetProtocol() { func (device *Device) resetProtocol() {
// restore default message type values // restore default message type values
MessageInitiationType = 1 MessageInitiationType = DefaultMessageInitiationType
MessageResponseType = 2 MessageResponseType = DefaultMessageResponseType
MessageCookieReplyType = 3 MessageCookieReplyType = DefaultMessageCookieReplyType
MessageTransportType = 4 MessageTransportType = DefaultMessageTransportType
} }
func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { func (device *Device) handlePostConfig(tempAwg *awg) (err error) {
if !tempASecCfg.isSet { if !tempAwg.aSecCfg.isSet {
return err return err
} }
isASecOn := false isASecOn := false
device.aSecMux.Lock() device.awg.aSecMux.Lock()
if tempASecCfg.junkPacketCount < 0 { if tempAwg.aSecCfg.junkPacketCount < 0 {
err = ipcErrorf( err = ipcErrorf(
ipc.IpcErrorInvalid, ipc.IpcErrorInvalid,
"JunkPacketCount should be non negative", "JunkPacketCount should be non negative",
) )
} }
device.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount device.awg.aSecCfg.junkPacketCount = tempAwg.aSecCfg.junkPacketCount
if tempASecCfg.junkPacketCount != 0 { if tempAwg.aSecCfg.junkPacketCount != 0 {
isASecOn = true isASecOn = true
} }
device.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize device.awg.aSecCfg.junkPacketMinSize = tempAwg.aSecCfg.junkPacketMinSize
if tempASecCfg.junkPacketMinSize != 0 { if tempAwg.aSecCfg.junkPacketMinSize != 0 {
isASecOn = true isASecOn = true
} }
if device.aSecCfg.junkPacketCount > 0 && if device.awg.aSecCfg.junkPacketCount > 0 &&
tempASecCfg.junkPacketMaxSize == tempASecCfg.junkPacketMinSize { tempAwg.aSecCfg.junkPacketMaxSize == tempAwg.aSecCfg.junkPacketMinSize {
tempASecCfg.junkPacketMaxSize++ // to make rand gen work tempAwg.aSecCfg.junkPacketMaxSize++ // to make rand gen work
} }
if tempASecCfg.junkPacketMaxSize >= MaxSegmentSize { if tempAwg.aSecCfg.junkPacketMaxSize >= MaxSegmentSize {
device.aSecCfg.junkPacketMinSize = 0 device.awg.aSecCfg.junkPacketMinSize = 0
device.aSecCfg.junkPacketMaxSize = 1 device.awg.aSecCfg.junkPacketMaxSize = 1
if err != nil { if err != nil {
err = ipcErrorf( err = ipcErrorf(
ipc.IpcErrorInvalid, ipc.IpcErrorInvalid,
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d; %w", "JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d; %w",
tempASecCfg.junkPacketMaxSize, tempAwg.aSecCfg.junkPacketMaxSize,
MaxSegmentSize, MaxSegmentSize,
err, err,
) )
@ -614,41 +659,41 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
err = ipcErrorf( err = ipcErrorf(
ipc.IpcErrorInvalid, ipc.IpcErrorInvalid,
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d", "JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
tempASecCfg.junkPacketMaxSize, tempAwg.aSecCfg.junkPacketMaxSize,
MaxSegmentSize, MaxSegmentSize,
) )
} }
} else if tempASecCfg.junkPacketMaxSize < tempASecCfg.junkPacketMinSize { } else if tempAwg.aSecCfg.junkPacketMaxSize < tempAwg.aSecCfg.junkPacketMinSize {
if err != nil { if err != nil {
err = ipcErrorf( err = ipcErrorf(
ipc.IpcErrorInvalid, ipc.IpcErrorInvalid,
"maxSize: %d; should be greater than minSize: %d; %w", "maxSize: %d; should be greater than minSize: %d; %w",
tempASecCfg.junkPacketMaxSize, tempAwg.aSecCfg.junkPacketMaxSize,
tempASecCfg.junkPacketMinSize, tempAwg.aSecCfg.junkPacketMinSize,
err, err,
) )
} else { } else {
err = ipcErrorf( err = ipcErrorf(
ipc.IpcErrorInvalid, ipc.IpcErrorInvalid,
"maxSize: %d; should be greater than minSize: %d", "maxSize: %d; should be greater than minSize: %d",
tempASecCfg.junkPacketMaxSize, tempAwg.aSecCfg.junkPacketMaxSize,
tempASecCfg.junkPacketMinSize, tempAwg.aSecCfg.junkPacketMinSize,
) )
} }
} else { } else {
device.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize device.awg.aSecCfg.junkPacketMaxSize = tempAwg.aSecCfg.junkPacketMaxSize
} }
if tempASecCfg.junkPacketMaxSize != 0 { if tempAwg.aSecCfg.junkPacketMaxSize != 0 {
isASecOn = true isASecOn = true
} }
if MessageInitiationSize+tempASecCfg.initPacketJunkSize >= MaxSegmentSize { if MessageInitiationSize+tempAwg.aSecCfg.initPacketJunkSize >= MaxSegmentSize {
if err != nil { if err != nil {
err = ipcErrorf( err = ipcErrorf(
ipc.IpcErrorInvalid, ipc.IpcErrorInvalid,
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`, `init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
tempASecCfg.initPacketJunkSize, tempAwg.aSecCfg.initPacketJunkSize,
MaxSegmentSize, MaxSegmentSize,
err, err,
) )
@ -656,24 +701,24 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
err = ipcErrorf( err = ipcErrorf(
ipc.IpcErrorInvalid, ipc.IpcErrorInvalid,
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`, `init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempASecCfg.initPacketJunkSize, tempAwg.aSecCfg.initPacketJunkSize,
MaxSegmentSize, MaxSegmentSize,
) )
} }
} else { } else {
device.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize device.awg.aSecCfg.initPacketJunkSize = tempAwg.aSecCfg.initPacketJunkSize
} }
if tempASecCfg.initPacketJunkSize != 0 { if tempAwg.aSecCfg.initPacketJunkSize != 0 {
isASecOn = true isASecOn = true
} }
if MessageResponseSize+tempASecCfg.responsePacketJunkSize >= MaxSegmentSize { if MessageResponseSize+tempAwg.aSecCfg.responsePacketJunkSize >= MaxSegmentSize {
if err != nil { if err != nil {
err = ipcErrorf( err = ipcErrorf(
ipc.IpcErrorInvalid, ipc.IpcErrorInvalid,
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`, `response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
tempASecCfg.responsePacketJunkSize, tempAwg.aSecCfg.responsePacketJunkSize,
MaxSegmentSize, MaxSegmentSize,
err, err,
) )
@ -681,63 +726,64 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
err = ipcErrorf( err = ipcErrorf(
ipc.IpcErrorInvalid, ipc.IpcErrorInvalid,
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`, `response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempASecCfg.responsePacketJunkSize, tempAwg.aSecCfg.responsePacketJunkSize,
MaxSegmentSize, MaxSegmentSize,
) )
} }
} else { } else {
device.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize device.awg.aSecCfg.responsePacketJunkSize = tempAwg.aSecCfg.responsePacketJunkSize
} }
if tempASecCfg.responsePacketJunkSize != 0 { if tempAwg.aSecCfg.responsePacketJunkSize != 0 {
isASecOn = true isASecOn = true
} }
if tempASecCfg.initPacketMagicHeader > 4 { if tempAwg.aSecCfg.initPacketMagicHeader > 4 {
isASecOn = true isASecOn = true
device.log.Verbosef("UAPI: Updating init_packet_magic_header") device.log.Verbosef("UAPI: Updating init_packet_magic_header")
device.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader device.awg.aSecCfg.initPacketMagicHeader = tempAwg.aSecCfg.initPacketMagicHeader
MessageInitiationType = device.aSecCfg.initPacketMagicHeader MessageInitiationType = device.awg.aSecCfg.initPacketMagicHeader
} else { } else {
device.log.Verbosef("UAPI: Using default init type") device.log.Verbosef("UAPI: Using default init type")
MessageInitiationType = 1 MessageInitiationType = DefaultMessageInitiationType
} }
if tempASecCfg.responsePacketMagicHeader > 4 { if tempAwg.aSecCfg.responsePacketMagicHeader > 4 {
isASecOn = true isASecOn = true
device.log.Verbosef("UAPI: Updating response_packet_magic_header") device.log.Verbosef("UAPI: Updating response_packet_magic_header")
device.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader device.awg.aSecCfg.responsePacketMagicHeader = tempAwg.aSecCfg.responsePacketMagicHeader
MessageResponseType = device.aSecCfg.responsePacketMagicHeader MessageResponseType = device.awg.aSecCfg.responsePacketMagicHeader
} else { } else {
device.log.Verbosef("UAPI: Using default response type") device.log.Verbosef("UAPI: Using default response type")
MessageResponseType = 2 MessageResponseType = DefaultMessageResponseType
} }
if tempASecCfg.underloadPacketMagicHeader > 4 { if tempAwg.aSecCfg.underloadPacketMagicHeader > 4 {
isASecOn = true isASecOn = true
device.log.Verbosef("UAPI: Updating underload_packet_magic_header") device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
device.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader device.awg.aSecCfg.underloadPacketMagicHeader = tempAwg.aSecCfg.underloadPacketMagicHeader
MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader MessageCookieReplyType = device.awg.aSecCfg.underloadPacketMagicHeader
} else { } else {
device.log.Verbosef("UAPI: Using default underload type") device.log.Verbosef("UAPI: Using default underload type")
MessageCookieReplyType = 3 MessageCookieReplyType = DefaultMessageCookieReplyType
} }
if tempASecCfg.transportPacketMagicHeader > 4 { if tempAwg.aSecCfg.transportPacketMagicHeader > 4 {
isASecOn = true isASecOn = true
device.log.Verbosef("UAPI: Updating transport_packet_magic_header") device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
device.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader device.awg.aSecCfg.transportPacketMagicHeader = tempAwg.aSecCfg.transportPacketMagicHeader
MessageTransportType = device.aSecCfg.transportPacketMagicHeader MessageTransportType = device.awg.aSecCfg.transportPacketMagicHeader
} else { } else {
device.log.Verbosef("UAPI: Using default transport type") device.log.Verbosef("UAPI: Using default transport type")
MessageTransportType = 4 MessageTransportType = DefaultMessageTransportType
} }
isSameMap := map[uint32]bool{} isSameMap := map[uint32]struct{}{
isSameMap[MessageInitiationType] = true MessageInitiationType: {},
isSameMap[MessageResponseType] = true MessageResponseType: {},
isSameMap[MessageCookieReplyType] = true MessageCookieReplyType: {},
isSameMap[MessageTransportType] = true MessageTransportType: {},
}
// size will be different if same values // size will be different if same values
if len(isSameMap) != 4 { if len(isSameMap) != 4 {
@ -763,8 +809,8 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
} }
} }
newInitSize := MessageInitiationSize + device.aSecCfg.initPacketJunkSize newInitSize := MessageInitiationSize + device.awg.aSecCfg.initPacketJunkSize
newResponseSize := MessageResponseSize + device.aSecCfg.responsePacketJunkSize newResponseSize := MessageResponseSize + device.awg.aSecCfg.responsePacketJunkSize
if newInitSize == newResponseSize { if newInitSize == newResponseSize {
if err != nil { if err != nil {
@ -792,16 +838,23 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
} }
msgTypeToJunkSize = map[uint32]int{ msgTypeToJunkSize = map[uint32]int{
MessageInitiationType: device.aSecCfg.initPacketJunkSize, MessageInitiationType: device.awg.aSecCfg.initPacketJunkSize,
MessageResponseType: device.aSecCfg.responsePacketJunkSize, MessageResponseType: device.awg.aSecCfg.responsePacketJunkSize,
MessageCookieReplyType: 0, MessageCookieReplyType: 0,
MessageTransportType: 0, MessageTransportType: 0,
} }
} }
device.isASecOn.SetTo(isASecOn) if err := tempAwg.handshakeHandler.Validate(); err == nil {
device.junkCreator, err = NewJunkCreator(device) return ipcErrorf(ipc.IpcErrorInvalid, "handle post config foo validate: %w", err)
device.aSecMux.Unlock() }
device.awg.isASecOn.SetTo(isASecOn)
device.awg.junkCreator, err = NewJunkCreator(device)
device.awg.handshakeHandler = tempAwg.handshakeHandler
// TODO:
device.version = VersionAwgSpecialHandshake
device.awg.aSecMux.Unlock()
return err return err
} }

View file

@ -274,6 +274,7 @@ func TestTwoDevicePing(t *testing.T) {
}) })
} }
// Run test with -race=false to avoid the race for setting the default msgTypes 2 times
func TestASecurityTwoDevicePing(t *testing.T) { func TestASecurityTwoDevicePing(t *testing.T) {
goroutineLeakCheck(t) goroutineLeakCheck(t)
pair := genTestPair(t, true, true) pair := genTestPair(t, true, true)

View file

@ -2,6 +2,7 @@ package junktag
import ( import (
crand "crypto/rand" crand "crypto/rand"
"encoding/binary"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"strconv" "strconv"
@ -12,17 +13,23 @@ import (
) )
type Generator interface { type Generator interface {
Generate() ([]byte, error) Generate() []byte
Size() int
} }
type newGenerator func(string) (Generator, error) type newGenerator func(string) (Generator, error)
type BytesGenerator struct { type BytesGenerator struct {
value []byte value []byte
size int
} }
func (bg *BytesGenerator) Generate() ([]byte, error) { func (bg *BytesGenerator) Generate() []byte {
return bg.value, nil return bg.value
}
func (bg *BytesGenerator) Size() int {
return bg.size
} }
func newBytesGenerator(param string) (Generator, error) { func newBytesGenerator(param string) (Generator, error) {
@ -37,7 +44,7 @@ func newBytesGenerator(param string) (Generator, error) {
return nil, fmt.Errorf("hexToBytes: %w", err) return nil, fmt.Errorf("hexToBytes: %w", err)
} }
return &BytesGenerator{value: hex}, nil return &BytesGenerator{value: hex, size: len(hex)}, nil
} }
func isHexString(s string) bool { func isHexString(s string) bool {
@ -68,23 +75,30 @@ type RandomPacketGenerator struct {
size int size int
} }
func (rpg *RandomPacketGenerator) Generate() ([]byte, error) { func (rpg *RandomPacketGenerator) Generate() []byte {
junk := make([]byte, rpg.size) junk := make([]byte, rpg.size)
_, err := rpg.cha8Rand.Read(junk) rpg.cha8Rand.Read(junk)
return junk, err return junk
}
func (rpg *RandomPacketGenerator) Size() int {
return rpg.size
} }
func newRandomPacketGenerator(param string) (Generator, error) { func newRandomPacketGenerator(param string) (Generator, error) {
size, err := strconv.Atoi(param) size, err := strconv.Atoi(param)
if err != nil { if err != nil {
return nil, fmt.Errorf("randome packet parse int: %w", err) return nil, fmt.Errorf("random packet parse int: %w", err)
}
if size > 1000 {
return nil, fmt.Errorf("random packet size must be less than 1000")
} }
// TODO: add size check
buf := make([]byte, 32) buf := make([]byte, 32)
_, err = crand.Read(buf) _, err = crand.Read(buf)
if err != nil { if err != nil {
return nil, fmt.Errorf("randome packet crand read: %w", err) return nil, fmt.Errorf("random packet crand read: %w", err)
} }
return &RandomPacketGenerator{cha8Rand: v2.NewChaCha8([32]byte(buf)), size: size}, nil return &RandomPacketGenerator{cha8Rand: v2.NewChaCha8([32]byte(buf)), size: size}, nil
@ -93,8 +107,14 @@ func newRandomPacketGenerator(param string) (Generator, error) {
type TimestampGenerator struct { type TimestampGenerator struct {
} }
func (tg *TimestampGenerator) Generate() ([]byte, error) { func (tg *TimestampGenerator) Generate() []byte {
return time.Now().MarshalBinary() buf := make([]byte, 8)
binary.BigEndian.PutUint64(buf, uint64(time.Now().Unix()))
return buf
}
func (tg *TimestampGenerator) Size() int {
return 8
} }
func newTimestampGenerator(param string) (Generator, error) { func newTimestampGenerator(param string) (Generator, error) {
@ -104,3 +124,29 @@ func newTimestampGenerator(param string) (Generator, error) {
return &TimestampGenerator{}, nil return &TimestampGenerator{}, nil
} }
type WaitTimeoutGenerator struct {
waitTimeout time.Duration
}
func (wtg *WaitTimeoutGenerator) Generate() []byte {
time.Sleep(wtg.waitTimeout)
return []byte{}
}
func (wtg *WaitTimeoutGenerator) Size() int {
return 0
}
func newWaitTimeoutGenerator(param string) (Generator, error) {
size, err := strconv.Atoi(param)
if err != nil {
return nil, fmt.Errorf("timeout parse int: %w", err)
}
if size > 5000 {
return nil, fmt.Errorf("timeout size must be less than 5000ms")
}
return &WaitTimeoutGenerator{}, nil
}

View file

@ -66,7 +66,7 @@ func Test_newBytesGenerator(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
require.NotNil(t, got) require.NotNil(t, got)
gotValues, _ := got.Generate() gotValues := got.Generate()
require.Equal(t, tt.want, gotValues) require.Equal(t, tt.want, gotValues)
}) })
} }
@ -95,6 +95,13 @@ func Test_newRandomPacketGenerator(t *testing.T) {
}, },
wantErr: fmt.Errorf("parse int"), wantErr: fmt.Errorf("parse int"),
}, },
{
name: "too large",
args: args{
param: "1001",
},
wantErr: fmt.Errorf("random packet size must be less than 1000"),
},
{ {
name: "valid", name: "valid",
args: args{ args: args{
@ -113,11 +120,9 @@ func Test_newRandomPacketGenerator(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
require.NotNil(t, got) require.NotNil(t, got)
first, err := got.Generate() first := got.Generate()
require.Nil(t, err)
second, err := got.Generate() second := got.Generate()
require.Nil(t, err)
require.NotEqual(t, first, second) require.NotEqual(t, first, second)
}) })
} }

View file

@ -2,36 +2,39 @@ package junktag
import ( import (
"fmt" "fmt"
"maps"
"regexp" "regexp"
"strings" "strings"
) )
type Enum string type EnumTag string
const ( const (
EnumBytes Enum = "b" BytesEnumTag EnumTag = "b"
EnumCounter Enum = "c" CounterEnumTag EnumTag = "c"
EnumTimestamp Enum = "t" TimestampEnumTag EnumTag = "t"
EnumRandomBytes Enum = "r" RandomBytesEnumTag EnumTag = "r"
EnumWaitTimeout Enum = "wt" WaitTimeoutEnumTag EnumTag = "wt"
EnumWaitResponse Enum = "wr" WaitResponseEnumTag EnumTag = "wr"
) )
var validEnum = map[Enum]newGenerator{ var generatorCreator = map[EnumTag]newGenerator{
EnumBytes: newBytesGenerator, BytesEnumTag: newBytesGenerator,
EnumCounter: func(s string) (Generator, error) { return &BytesGenerator{}, nil }, CounterEnumTag: func(s string) (Generator, error) { return &BytesGenerator{}, nil },
EnumTimestamp: newTimestampGenerator, TimestampEnumTag: newTimestampGenerator,
EnumRandomBytes: newRandomPacketGenerator, RandomBytesEnumTag: newRandomPacketGenerator,
EnumWaitTimeout: func(s string) (Generator, error) { return &BytesGenerator{}, nil }, WaitTimeoutEnumTag: newWaitTimeoutGenerator,
EnumWaitResponse: func(s string) (Generator, error) { return &BytesGenerator{}, nil }, WaitResponseEnumTag: func(s string) (Generator, error) { return &BytesGenerator{}, nil },
} }
type Foo struct { // helper map to determine enumTags are unique
x []Generator var uniqueTags = map[EnumTag]bool{
CounterEnumTag: false,
TimestampEnumTag: false,
} }
type Tag struct { type Tag struct {
Name Enum Name EnumTag
Param string Param string
} }
@ -41,7 +44,7 @@ func parseTag(input string) (Tag, error) {
match := re.FindStringSubmatch(input) match := re.FindStringSubmatch(input)
tag := Tag{ tag := Tag{
Name: Enum(match[1]), Name: EnumTag(match[1]),
} }
if len(match) > 2 && match[2] != "" { if len(match) > 2 && match[2] != "" {
tag.Param = strings.TrimSpace(match[2]) tag.Param = strings.TrimSpace(match[2])
@ -50,35 +53,43 @@ func parseTag(input string) (Tag, error) {
return tag, nil return tag, nil
} }
func Parse(input string) (Foo, error) { // TODO: pointernes
func Parse(name, input string) (TaggedJunkGenerator, error) {
inputSlice := strings.Split(input, "<") inputSlice := strings.Split(input, "<")
fmt.Printf("%v\n", inputSlice)
if len(inputSlice) <= 1 { if len(inputSlice) <= 1 {
return Foo{}, fmt.Errorf("empty input: %s", input) return TaggedJunkGenerator{}, fmt.Errorf("empty input: %s", input)
} }
uniqueTagCheck := make(map[EnumTag]bool, len(uniqueTags))
maps.Copy(uniqueTagCheck, uniqueTags)
// skip byproduct of split // skip byproduct of split
inputSlice = inputSlice[1:] inputSlice = inputSlice[1:]
rv := Foo{x: make([]Generator, 0, len(inputSlice))} rv := newTagedJunkGenerator(name, len(inputSlice))
for _, inputParam := range inputSlice { for _, inputParam := range inputSlice {
if len(inputParam) <= 1 { if len(inputParam) <= 1 {
return Foo{}, fmt.Errorf("empty tag in input: %s", inputSlice) return TaggedJunkGenerator{}, fmt.Errorf("empty tag in input: %s", inputSlice)
} else if strings.Count(inputParam, ">") != 1 { } else if strings.Count(inputParam, ">") != 1 {
return Foo{}, fmt.Errorf("ill formated input: %s", input) return TaggedJunkGenerator{}, fmt.Errorf("ill formated input: %s", input)
} }
tag, _ := parseTag(inputParam) tag, _ := parseTag(inputParam)
fmt.Printf("Tag: %s, Param: %s\n", tag.Name, tag.Param) creator, ok := generatorCreator[tag.Name]
gen, ok := validEnum[tag.Name]
if !ok { if !ok {
return Foo{}, fmt.Errorf("invalid tag: %s", tag.Name) return TaggedJunkGenerator{}, fmt.Errorf("invalid tag: %s", tag.Name)
} }
generator, err := gen(tag.Param) if present, ok := uniqueTagCheck[tag.Name]; ok {
if present {
return TaggedJunkGenerator{}, fmt.Errorf("tag %s needs to be unique", tag.Name)
}
uniqueTagCheck[tag.Name] = true
}
generator, err := creator(tag.Param)
if err != nil { if err != nil {
return Foo{}, fmt.Errorf("gen: %w", err) return TaggedJunkGenerator{}, fmt.Errorf("gen: %w", err)
} }
rv.x = append(rv.x, generator)
rv.append(generator)
} }
return rv, nil return rv, nil

View file

@ -41,6 +41,16 @@ func TestParse(t *testing.T) {
args: args{input: "<q 0xf6ab3267fa>"}, args: args{input: "<q 0xf6ab3267fa>"},
wantErr: fmt.Errorf("invalid tag"), wantErr: fmt.Errorf("invalid tag"),
}, },
{
name: "counter uniqueness violation",
args: args{input: "<c><c>"},
wantErr: fmt.Errorf("parse tag needs to be unique"),
},
{
name: "timestamp uniqueness violation",
args: args{input: "<t><t>"},
wantErr: fmt.Errorf("parse tag needs to be unique"),
},
{ {
name: "valid", name: "valid",
args: args{input: "<b 0xf6ab3267fa><c><b 0xf6ab><t><r 10><wt 10>"}, args: args{input: "<b 0xf6ab3267fa><c><b 0xf6ab><t><r 10><wt 10>"},

View file

@ -0,0 +1,48 @@
package junktag
import (
"errors"
"time"
)
type SpecialHandshakeHandler struct {
SpecialJunk TaggedJunkGeneratorHandler
ControlledJunk TaggedJunkGeneratorHandler
nextItime time.Time
ITimeout time.Duration // seconds
// TODO: maybe atomic?
PacketCounter uint64
}
func (handler *SpecialHandshakeHandler) Validate() error {
var errs []error
if err := handler.SpecialJunk.Validate(); err != nil {
errs = append(errs, err)
}
if err := handler.ControlledJunk.Validate(); err != nil {
errs = append(errs, err)
}
return errors.Join(errs...)
}
func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte {
// TODO: distiungish between first and the rest of the packets
if !handler.isTimeToSendSpecial() {
return nil
}
rv := handler.SpecialJunk.Generate()
handler.nextItime = time.Now().Add(time.Duration(handler.ITimeout))
return rv
}
func (handler *SpecialHandshakeHandler) isTimeToSendSpecial() bool {
return time.Now().After(handler.nextItime)
}
func (handler *SpecialHandshakeHandler) PrepareControlledJunk() [][]byte {
return handler.ControlledJunk.Generate()
}

View file

@ -0,0 +1,42 @@
package junktag
import (
"fmt"
"strconv"
)
type TaggedJunkGenerator struct {
name string
packetSize int
generators []Generator
}
func newTagedJunkGenerator(name string, size int) TaggedJunkGenerator {
return TaggedJunkGenerator{name: name, generators: make([]Generator, size)}
}
func (tg *TaggedJunkGenerator) append(generator Generator) {
tg.generators = append(tg.generators, generator)
tg.packetSize += generator.Size()
}
func (tg *TaggedJunkGenerator) generate() []byte {
packet := make([]byte, 0, tg.packetSize)
for _, generator := range tg.generators {
packet = append(packet, generator.Generate()...)
}
return packet
}
func (t *TaggedJunkGenerator) nameIndex() (int, error) {
if len(t.name) != 2 {
return 0, fmt.Errorf("name must be 2 character long: %s", t.name)
}
index, err := strconv.Atoi(t.name[1:2])
if err != nil {
return 0, fmt.Errorf("name should be 2 char long: %w", err)
}
return index, nil
}

View file

@ -0,0 +1,43 @@
package junktag
import "fmt"
type TaggedJunkGeneratorHandler struct {
generators []TaggedJunkGenerator
length int
}
func (handler *TaggedJunkGeneratorHandler) AppendGenerator(generators TaggedJunkGenerator) {
handler.generators = append(handler.generators, generators)
handler.length++
}
// validate that packets were defined consecutively
func (handler *TaggedJunkGeneratorHandler) Validate() error {
seen := make([]bool, len(handler.generators))
for _, generator := range handler.generators {
if index, err := generator.nameIndex(); err != nil {
return fmt.Errorf("name index: %w", err)
} else {
seen[index-1] = true
}
}
for _, found := range seen {
if !found {
return fmt.Errorf("junk packet index should be consecutive")
}
}
return nil
}
func (handler *TaggedJunkGeneratorHandler) Generate() [][]byte {
var rv = make([][]byte, handler.length)
for i, generator := range handler.generators {
rv[i] = make([]byte, generator.packetSize)
copy(rv[i], generator.generate())
}
return rv
}

View file

@ -22,47 +22,48 @@ func NewJunkCreator(d *Device) (junkCreator, error) {
} }
// Should be called with aSecMux RLocked // Should be called with aSecMux RLocked
func (jc *junkCreator) createJunkPackets() ([][]byte, error) { func (jc *junkCreator) createJunkPackets(junks *[][]byte) error {
if jc.device.aSecCfg.junkPacketCount == 0 { if jc.device.awg.aSecCfg.junkPacketCount == 0 {
return nil, nil return nil
} }
junks := make([][]byte, 0, jc.device.aSecCfg.junkPacketCount) *junks = make([][]byte, len(*junks)+jc.device.awg.aSecCfg.junkPacketCount)
for i := 0; i < jc.device.aSecCfg.junkPacketCount; i++ { for i := range jc.device.awg.aSecCfg.junkPacketCount {
packetSize := jc.randomPacketSize() packetSize := jc.randomPacketSize()
junk, err := jc.randomJunkWithSize(packetSize) junk, err := jc.randomJunkWithSize(packetSize)
if err != nil { if err != nil {
return nil, fmt.Errorf("Failed to create junk packet: %v", err) return fmt.Errorf("create junk packet: %v", err)
} }
junks = append(junks, junk) (*junks)[i] = junk
} }
return junks, nil return nil
} }
// Should be called with aSecMux RLocked // Should be called with aSecMux RLocked
func (jc *junkCreator) randomPacketSize() int { func (jc *junkCreator) randomPacketSize() int {
return int( return int(
jc.cha8Rand.Uint64()%uint64( jc.cha8Rand.Uint64()%uint64(
jc.device.aSecCfg.junkPacketMaxSize-jc.device.aSecCfg.junkPacketMinSize, jc.device.awg.aSecCfg.junkPacketMaxSize-jc.device.awg.aSecCfg.junkPacketMinSize,
), ),
) + jc.device.aSecCfg.junkPacketMinSize ) + jc.device.awg.aSecCfg.junkPacketMinSize
} }
// Should be called with aSecMux RLocked // Should be called with aSecMux RLocked
func (jc *junkCreator) appendJunk(writer *bytes.Buffer, size int) error { func (jc *junkCreator) appendJunk(writer *bytes.Buffer, size int) error {
headerJunk, err := jc.randomJunkWithSize(size) headerJunk, err := jc.randomJunkWithSize(size)
if err != nil { if err != nil {
return fmt.Errorf("failed to create header junk: %v", err) return fmt.Errorf("create header junk: %v", err)
} }
_, err = writer.Write(headerJunk) _, err = writer.Write(headerJunk)
if err != nil { if err != nil {
return fmt.Errorf("failed to write header junk: %v", err) return fmt.Errorf("write header junk: %v", err)
} }
return nil return nil
} }
// Should be called with aSecMux RLocked // Should be called with aSecMux RLocked
func (jc *junkCreator) randomJunkWithSize(size int) ([]byte, error) { func (jc *junkCreator) randomJunkWithSize(size int) ([]byte, error) {
// TODO: use a memory pool to allocate
junk := make([]byte, size) junk := make([]byte, size)
_, err := jc.cha8Rand.Read(junk) _, err := jc.cha8Rand.Read(junk)
return junk, err return junk, err

View file

@ -91,13 +91,13 @@ func Test_junkCreator_randomPacketSize(t *testing.T) {
} }
for range [30]struct{}{} { for range [30]struct{}{} {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
if got := jc.randomPacketSize(); jc.device.aSecCfg.junkPacketMinSize > got || if got := jc.randomPacketSize(); jc.device.awg.aSecCfg.junkPacketMinSize > got ||
got > jc.device.aSecCfg.junkPacketMaxSize { got > jc.device.awg.aSecCfg.junkPacketMaxSize {
t.Errorf( t.Errorf(
"junkCreator.randomPacketSize() = %v, not between range [%v,%v]", "junkCreator.randomPacketSize() = %v, not between range [%v,%v]",
got, got,
jc.device.aSecCfg.junkPacketMinSize, jc.device.awg.aSecCfg.junkPacketMinSize,
jc.device.aSecCfg.junkPacketMaxSize, jc.device.awg.aSecCfg.junkPacketMaxSize,
) )
} }
}) })

View file

@ -52,11 +52,18 @@ const (
WGLabelCookie = "cookie--" WGLabelCookie = "cookie--"
) )
const (
DefaultMessageInitiationType uint32 = 1
DefaultMessageResponseType uint32 = 2
DefaultMessageCookieReplyType uint32 = 3
DefaultMessageTransportType uint32 = 4
)
var ( var (
MessageInitiationType uint32 = 1 MessageInitiationType uint32 = DefaultMessageInitiationType
MessageResponseType uint32 = 2 MessageResponseType uint32 = DefaultMessageResponseType
MessageCookieReplyType uint32 = 3 MessageCookieReplyType uint32 = DefaultMessageCookieReplyType
MessageTransportType uint32 = 4 MessageTransportType uint32 = DefaultMessageTransportType
) )
const ( const (
@ -197,12 +204,12 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(handshake.remoteStatic[:]) handshake.mixHash(handshake.remoteStatic[:])
device.aSecMux.RLock() device.awg.aSecMux.RLock()
msg := MessageInitiation{ msg := MessageInitiation{
Type: MessageInitiationType, Type: MessageInitiationType,
Ephemeral: handshake.localEphemeral.publicKey(), Ephemeral: handshake.localEphemeral.publicKey(),
} }
device.aSecMux.RUnlock() device.awg.aSecMux.RUnlock()
handshake.mixKey(msg.Ephemeral[:]) handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:])
@ -256,12 +263,12 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
chainKey [blake2s.Size]byte chainKey [blake2s.Size]byte
) )
device.aSecMux.RLock() device.awg.aSecMux.RLock()
if msg.Type != MessageInitiationType { if msg.Type != MessageInitiationType {
device.aSecMux.RUnlock() device.awg.aSecMux.RUnlock()
return nil return nil
} }
device.aSecMux.RUnlock() device.awg.aSecMux.RUnlock()
device.staticIdentity.RLock() device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock() defer device.staticIdentity.RUnlock()
@ -376,9 +383,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
} }
var msg MessageResponse var msg MessageResponse
device.aSecMux.RLock() device.awg.aSecMux.RLock()
msg.Type = MessageResponseType msg.Type = MessageResponseType
device.aSecMux.RUnlock() device.awg.aSecMux.RUnlock()
msg.Sender = handshake.localIndex msg.Sender = handshake.localIndex
msg.Receiver = handshake.remoteIndex msg.Receiver = handshake.remoteIndex
@ -428,12 +435,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
} }
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
device.aSecMux.RLock() device.awg.aSecMux.RLock()
if msg.Type != MessageResponseType { if msg.Type != MessageResponseType {
device.aSecMux.RUnlock() device.awg.aSecMux.RUnlock()
return nil return nil
} }
device.aSecMux.RUnlock() device.awg.aSecMux.RUnlock()
// lookup handshake by receiver // lookup handshake by receiver

View file

@ -137,6 +137,7 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error {
if err == nil { if err == nil {
var totalLen uint64 var totalLen uint64
for _, b := range buffers { for _, b := range buffers {
peer.device.awg.foo.PacketCounter++
totalLen += uint64(len(b)) totalLen += uint64(len(b))
} }
peer.txBytes.Add(totalLen) peer.txBytes.Add(totalLen)

View file

@ -129,7 +129,7 @@ func (device *Device) RoutineReceiveIncoming(
} }
deathSpiral = 0 deathSpiral = 0
device.aSecMux.RLock() device.awg.aSecMux.RLock()
// handle each packet in the batch // handle each packet in the batch
for i, size := range sizes[:count] { for i, size := range sizes[:count] {
if size < MinMessageSize { if size < MinMessageSize {
@ -149,13 +149,14 @@ func (device *Device) RoutineReceiveIncoming(
if msgType == assumedMsgType { if msgType == assumedMsgType {
packet = packet[junkSize:] packet = packet[junkSize:]
} else { } else {
device.log.Verbosef("Transport packet lined up with another msg type") device.log.Verbosef("transport packet lined up with another msg type")
msgType = binary.LittleEndian.Uint32(packet[:4]) msgType = binary.LittleEndian.Uint32(packet[:4])
} }
} else { } else {
msgType = binary.LittleEndian.Uint32(packet[:4]) msgType = binary.LittleEndian.Uint32(packet[:4])
if msgType != MessageTransportType { if msgType != MessageTransportType {
device.log.Verbosef("ASec: Received message with unknown type") // probably a junk packet
device.log.Verbosef("aSec: Received message with unknown type: %d", msgType)
continue continue
} }
} }
@ -245,7 +246,7 @@ func (device *Device) RoutineReceiveIncoming(
default: default:
} }
} }
device.aSecMux.RUnlock() device.awg.aSecMux.RUnlock()
for peer, elemsContainer := range elemsByPeer { for peer, elemsContainer := range elemsByPeer {
if peer.isRunning.Load() { if peer.isRunning.Load() {
peer.queue.inbound.c <- elemsContainer peer.queue.inbound.c <- elemsContainer
@ -304,7 +305,7 @@ func (device *Device) RoutineHandshake(id int) {
for elem := range device.queue.handshake.c { for elem := range device.queue.handshake.c {
device.aSecMux.RLock() device.awg.aSecMux.RLock()
// handle cookie fields and ratelimiting // handle cookie fields and ratelimiting
@ -456,7 +457,7 @@ func (device *Device) RoutineHandshake(id int) {
peer.SendKeepalive() peer.SendKeepalive()
} }
skip: skip:
device.aSecMux.RUnlock() device.awg.aSecMux.RUnlock()
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
} }
} }

View file

@ -126,10 +126,21 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
var sendBuffer [][]byte var sendBuffer [][]byte
// so only packet processed for cookie generation // so only packet processed for cookie generation
var junkedHeader []byte var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() {
peer.device.aSecMux.RLock() if peer.device.version >= VersionAwg {
junks, err := peer.device.junkCreator.createJunkPackets() junks := [][]byte{}
peer.device.aSecMux.RUnlock() if peer.device.version == VersionAwgSpecialHandshake {
peer.device.awg.aSecMux.RLock()
// set junks depending on packet type
junks = peer.device.awg.handshakeHandler.GenerateSpecialJunk()
if junks == nil {
junks = peer.device.awg.handshakeHandler.GenerateSpecialJunk()
}
peer.device.awg.aSecMux.RUnlock()
}
peer.device.awg.aSecMux.RLock()
err := peer.device.awg.junkCreator.createJunkPackets(&junks)
peer.device.awg.aSecMux.RUnlock()
if err != nil { if err != nil {
peer.device.log.Errorf("%v - %v", peer, err) peer.device.log.Errorf("%v - %v", peer, err)
@ -145,19 +156,19 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
} }
} }
peer.device.aSecMux.RLock() peer.device.awg.aSecMux.RLock()
if peer.device.aSecCfg.initPacketJunkSize != 0 { if peer.device.awg.aSecCfg.initPacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize) buf := make([]byte, 0, peer.device.awg.aSecCfg.initPacketJunkSize)
writer := bytes.NewBuffer(buf[:0]) writer := bytes.NewBuffer(buf[:0])
err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize) err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.initPacketJunkSize)
if err != nil { if err != nil {
peer.device.log.Errorf("%v - %v", peer, err) peer.device.log.Errorf("%v - %v", peer, err)
peer.device.aSecMux.RUnlock() peer.device.awg.aSecMux.RUnlock()
return err return err
} }
junkedHeader = writer.Bytes() junkedHeader = writer.Bytes()
} }
peer.device.aSecMux.RUnlock() peer.device.awg.aSecMux.RUnlock()
} }
var buf [MessageInitiationSize]byte var buf [MessageInitiationSize]byte
@ -195,19 +206,19 @@ func (peer *Peer) SendHandshakeResponse() error {
} }
var junkedHeader []byte var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() { if peer.device.isAdvancedSecurityOn() {
peer.device.aSecMux.RLock() peer.device.awg.aSecMux.RLock()
if peer.device.aSecCfg.responsePacketJunkSize != 0 { if peer.device.awg.aSecCfg.responsePacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize) buf := make([]byte, 0, peer.device.awg.aSecCfg.responsePacketJunkSize)
writer := bytes.NewBuffer(buf[:0]) writer := bytes.NewBuffer(buf[:0])
err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize) err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.responsePacketJunkSize)
if err != nil { if err != nil {
peer.device.aSecMux.RUnlock() peer.device.awg.aSecMux.RUnlock()
peer.device.log.Errorf("%v - %v", peer, err) peer.device.log.Errorf("%v - %v", peer, err)
return err return err
} }
junkedHeader = writer.Bytes() junkedHeader = writer.Bytes()
} }
peer.device.aSecMux.RUnlock() peer.device.awg.aSecMux.RUnlock()
} }
var buf [MessageResponseSize]byte var buf [MessageResponseSize]byte
writer := bytes.NewBuffer(buf[:0]) writer := bytes.NewBuffer(buf[:0])

View file

@ -18,6 +18,7 @@ import (
"sync" "sync"
"time" "time"
junktag "github.com/amnezia-vpn/amneziawg-go/device/internal/junk-tag"
"github.com/amnezia-vpn/amneziawg-go/ipc" "github.com/amnezia-vpn/amneziawg-go/ipc"
) )
@ -98,32 +99,32 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
} }
if device.isAdvancedSecurityOn() { if device.isAdvancedSecurityOn() {
if device.aSecCfg.junkPacketCount != 0 { if device.awg.aSecCfg.junkPacketCount != 0 {
sendf("jc=%d", device.aSecCfg.junkPacketCount) sendf("jc=%d", device.awg.aSecCfg.junkPacketCount)
} }
if device.aSecCfg.junkPacketMinSize != 0 { if device.awg.aSecCfg.junkPacketMinSize != 0 {
sendf("jmin=%d", device.aSecCfg.junkPacketMinSize) sendf("jmin=%d", device.awg.aSecCfg.junkPacketMinSize)
} }
if device.aSecCfg.junkPacketMaxSize != 0 { if device.awg.aSecCfg.junkPacketMaxSize != 0 {
sendf("jmax=%d", device.aSecCfg.junkPacketMaxSize) sendf("jmax=%d", device.awg.aSecCfg.junkPacketMaxSize)
} }
if device.aSecCfg.initPacketJunkSize != 0 { if device.awg.aSecCfg.initPacketJunkSize != 0 {
sendf("s1=%d", device.aSecCfg.initPacketJunkSize) sendf("s1=%d", device.awg.aSecCfg.initPacketJunkSize)
} }
if device.aSecCfg.responsePacketJunkSize != 0 { if device.awg.aSecCfg.responsePacketJunkSize != 0 {
sendf("s2=%d", device.aSecCfg.responsePacketJunkSize) sendf("s2=%d", device.awg.aSecCfg.responsePacketJunkSize)
} }
if device.aSecCfg.initPacketMagicHeader != 0 { if device.awg.aSecCfg.initPacketMagicHeader != 0 {
sendf("h1=%d", device.aSecCfg.initPacketMagicHeader) sendf("h1=%d", device.awg.aSecCfg.initPacketMagicHeader)
} }
if device.aSecCfg.responsePacketMagicHeader != 0 { if device.awg.aSecCfg.responsePacketMagicHeader != 0 {
sendf("h2=%d", device.aSecCfg.responsePacketMagicHeader) sendf("h2=%d", device.awg.aSecCfg.responsePacketMagicHeader)
} }
if device.aSecCfg.underloadPacketMagicHeader != 0 { if device.awg.aSecCfg.underloadPacketMagicHeader != 0 {
sendf("h3=%d", device.aSecCfg.underloadPacketMagicHeader) sendf("h3=%d", device.awg.aSecCfg.underloadPacketMagicHeader)
} }
if device.aSecCfg.transportPacketMagicHeader != 0 { if device.awg.aSecCfg.transportPacketMagicHeader != 0 {
sendf("h4=%d", device.aSecCfg.transportPacketMagicHeader) sendf("h4=%d", device.awg.aSecCfg.transportPacketMagicHeader)
} }
} }
@ -180,13 +181,13 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
peer := new(ipcSetPeer) peer := new(ipcSetPeer)
deviceConfig := true deviceConfig := true
tempASecCfg := aSecCfgType{} tempAwg := awg{}
scanner := bufio.NewScanner(r) scanner := bufio.NewScanner(r)
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
if line == "" { if line == "" {
// Blank line means terminate operation. // Blank line means terminate operation.
err := device.handlePostConfig(&tempASecCfg) err := device.handlePostConfig(&tempAwg)
if err != nil { if err != nil {
return err return err
} }
@ -217,7 +218,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
var err error var err error
if deviceConfig { if deviceConfig {
err = device.handleDeviceLine(key, value, &tempASecCfg) err = device.handleDeviceLine(key, value, &tempAwg)
} else { } else {
err = device.handlePeerLine(peer, key, value) err = device.handlePeerLine(peer, key, value)
} }
@ -225,7 +226,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
return err return err
} }
} }
err = device.handlePostConfig(&tempASecCfg) err = device.handlePostConfig(&tempAwg)
if err != nil { if err != nil {
return err return err
} }
@ -237,7 +238,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
return nil return nil
} }
func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgType) error { func (device *Device) handleDeviceLine(key, value string, tempAwg *awg) error {
switch key { switch key {
case "private_key": case "private_key":
var sk NoisePrivateKey var sk NoisePrivateKey
@ -286,80 +287,113 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
case "jc": case "jc":
junkPacketCount, err := strconv.Atoi(value) junkPacketCount, err := strconv.Atoi(value)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_count %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_count %w", err)
} }
device.log.Verbosef("UAPI: Updating junk_packet_count") device.log.Verbosef("UAPI: Updating junk_packet_count")
tempASecCfg.junkPacketCount = junkPacketCount tempAwg.aSecCfg.junkPacketCount = junkPacketCount
tempASecCfg.isSet = true tempAwg.aSecCfg.isSet = true
case "jmin": case "jmin":
junkPacketMinSize, err := strconv.Atoi(value) junkPacketMinSize, err := strconv.Atoi(value)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_min_size %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_min_size %w", err)
} }
device.log.Verbosef("UAPI: Updating junk_packet_min_size") device.log.Verbosef("UAPI: Updating junk_packet_min_size")
tempASecCfg.junkPacketMinSize = junkPacketMinSize tempAwg.aSecCfg.junkPacketMinSize = junkPacketMinSize
tempASecCfg.isSet = true tempAwg.aSecCfg.isSet = true
case "jmax": case "jmax":
junkPacketMaxSize, err := strconv.Atoi(value) junkPacketMaxSize, err := strconv.Atoi(value)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_max_size %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_max_size %w", err)
} }
device.log.Verbosef("UAPI: Updating junk_packet_max_size") device.log.Verbosef("UAPI: Updating junk_packet_max_size")
tempASecCfg.junkPacketMaxSize = junkPacketMaxSize tempAwg.aSecCfg.junkPacketMaxSize = junkPacketMaxSize
tempASecCfg.isSet = true tempAwg.aSecCfg.isSet = true
case "s1": case "s1":
initPacketJunkSize, err := strconv.Atoi(value) initPacketJunkSize, err := strconv.Atoi(value)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_junk_size %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_junk_size %w", err)
} }
device.log.Verbosef("UAPI: Updating init_packet_junk_size") device.log.Verbosef("UAPI: Updating init_packet_junk_size")
tempASecCfg.initPacketJunkSize = initPacketJunkSize tempAwg.aSecCfg.initPacketJunkSize = initPacketJunkSize
tempASecCfg.isSet = true tempAwg.aSecCfg.isSet = true
case "s2": case "s2":
responsePacketJunkSize, err := strconv.Atoi(value) responsePacketJunkSize, err := strconv.Atoi(value)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_junk_size %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_junk_size %w", err)
} }
device.log.Verbosef("UAPI: Updating response_packet_junk_size") device.log.Verbosef("UAPI: Updating response_packet_junk_size")
tempASecCfg.responsePacketJunkSize = responsePacketJunkSize tempAwg.aSecCfg.responsePacketJunkSize = responsePacketJunkSize
tempASecCfg.isSet = true tempAwg.aSecCfg.isSet = true
case "h1": case "h1":
initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_magic_header %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_magic_header %w", err)
} }
tempASecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader) tempAwg.aSecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader)
tempASecCfg.isSet = true tempAwg.aSecCfg.isSet = true
case "h2": case "h2":
responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32) responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_magic_header %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_magic_header %w", err)
} }
tempASecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader) tempAwg.aSecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader)
tempASecCfg.isSet = true tempAwg.aSecCfg.isSet = true
case "h3": case "h3":
underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse underload_packet_magic_header %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "parse underload_packet_magic_header %w", err)
} }
tempASecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader) tempAwg.aSecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
tempASecCfg.isSet = true tempAwg.aSecCfg.isSet = true
case "h4": case "h4":
transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse transport_packet_magic_header %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_magic_header %w", err)
}
tempAwg.aSecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader)
tempAwg.aSecCfg.isSet = true
case "i1", "i2", "i3", "i4", "i5":
if len(value) == 0 {
return ipcErrorf(ipc.IpcErrorInvalid, "%s should be non null", key)
} }
tempASecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader)
tempASecCfg.isSet = true
generators, err := junktag.Parse(key, value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err)
}
device.log.Verbosef("UAPI: Updating %s", key)
tempAwg.handshakeHandler.SpecialJunk.AppendGenerator(generators)
tempAwg.aSecCfg.isSet = true
case "j1", "j2", "j3":
if len(value) == 0 {
return ipcErrorf(ipc.IpcErrorInvalid, "%s should be non null", key)
}
generators, err := junktag.Parse(key, value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err)
}
device.log.Verbosef("UAPI: Updating %s", key)
tempAwg.handshakeHandler.ControlledJunk.AppendGenerator(generators)
tempAwg.aSecCfg.isSet = true
case "itime":
itime, err := strconv.ParseInt(value, 10, 64)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse itime %w", err)
}
device.log.Verbosef("UAPI: Updating itime %s", itime)
tempAwg.handshakeHandler.ITimeout = time.Duration(itime)
tempAwg.aSecCfg.isSet = true
default: default:
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key) return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
} }