diff --git a/device/device.go b/device/device.go index 1be15d0..9ba69dd 100644 --- a/device/device.go +++ b/device/device.go @@ -12,6 +12,7 @@ import ( "time" "github.com/amnezia-vpn/amneziawg-go/conn" + junktag "github.com/amnezia-vpn/amneziawg-go/device/internal/junk-tag" "github.com/amnezia-vpn/amneziawg-go/ipc" "github.com/amnezia-vpn/amneziawg-go/ratelimiter" "github.com/amnezia-vpn/amneziawg-go/rwcancel" @@ -19,6 +20,41 @@ import ( "github.com/tevino/abool/v2" ) +type Version uint8 + +const ( + VersionDefault Version = iota + VersionAwg + VersionAwgSpecialHandshake +) + +// TODO: +type AtomicVersion struct { + value atomic.Uint32 +} + +func NewAtomicVersion(v Version) *AtomicVersion { + av := &AtomicVersion{} + av.Store(v) + return av +} + +func (av *AtomicVersion) Load() Version { + return Version(av.value.Load()) +} + +func (av *AtomicVersion) Store(v Version) { + av.value.Store(uint32(v)) +} + +func (av *AtomicVersion) CompareAndSwap(old, new Version) bool { + return av.value.CompareAndSwap(uint32(old), uint32(new)) +} + +func (av *AtomicVersion) Swap(new Version) Version { + return Version(av.value.Swap(uint32(new))) +} + type Device struct { state struct { // state holds the device's state. It is accessed atomically. @@ -92,10 +128,19 @@ type Device struct { closed chan struct{} log *Logger + version Version + awg awg +} + +type awg struct { isASecOn abool.AtomicBool - aSecMux sync.RWMutex - aSecCfg aSecCfgType + // TODO: revision the need of the mutex + aSecMux sync.RWMutex + aSecCfg aSecCfgType junkCreator junkCreator + + // TODO: determine if it's on + handshakeHandler junktag.SpecialHandshakeHandler } type aSecCfgType struct { @@ -558,55 +603,55 @@ func (device *Device) BindClose() error { return err } func (device *Device) isAdvancedSecurityOn() bool { - return device.isASecOn.IsSet() + return device.awg.isASecOn.IsSet() } func (device *Device) resetProtocol() { // restore default message type values - MessageInitiationType = 1 - MessageResponseType = 2 - MessageCookieReplyType = 3 - MessageTransportType = 4 + MessageInitiationType = DefaultMessageInitiationType + MessageResponseType = DefaultMessageResponseType + MessageCookieReplyType = DefaultMessageCookieReplyType + MessageTransportType = DefaultMessageTransportType } -func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { +func (device *Device) handlePostConfig(tempAwg *awg) (err error) { - if !tempASecCfg.isSet { + if !tempAwg.aSecCfg.isSet { return err } isASecOn := false - device.aSecMux.Lock() - if tempASecCfg.junkPacketCount < 0 { + device.awg.aSecMux.Lock() + if tempAwg.aSecCfg.junkPacketCount < 0 { err = ipcErrorf( ipc.IpcErrorInvalid, "JunkPacketCount should be non negative", ) } - device.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount - if tempASecCfg.junkPacketCount != 0 { + device.awg.aSecCfg.junkPacketCount = tempAwg.aSecCfg.junkPacketCount + if tempAwg.aSecCfg.junkPacketCount != 0 { isASecOn = true } - device.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize - if tempASecCfg.junkPacketMinSize != 0 { + device.awg.aSecCfg.junkPacketMinSize = tempAwg.aSecCfg.junkPacketMinSize + if tempAwg.aSecCfg.junkPacketMinSize != 0 { isASecOn = true } - if device.aSecCfg.junkPacketCount > 0 && - tempASecCfg.junkPacketMaxSize == tempASecCfg.junkPacketMinSize { + if device.awg.aSecCfg.junkPacketCount > 0 && + tempAwg.aSecCfg.junkPacketMaxSize == tempAwg.aSecCfg.junkPacketMinSize { - tempASecCfg.junkPacketMaxSize++ // to make rand gen work + tempAwg.aSecCfg.junkPacketMaxSize++ // to make rand gen work } - if tempASecCfg.junkPacketMaxSize >= MaxSegmentSize { - device.aSecCfg.junkPacketMinSize = 0 - device.aSecCfg.junkPacketMaxSize = 1 + if tempAwg.aSecCfg.junkPacketMaxSize >= MaxSegmentSize { + device.awg.aSecCfg.junkPacketMinSize = 0 + device.awg.aSecCfg.junkPacketMaxSize = 1 if err != nil { err = ipcErrorf( ipc.IpcErrorInvalid, "JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d; %w", - tempASecCfg.junkPacketMaxSize, + tempAwg.aSecCfg.junkPacketMaxSize, MaxSegmentSize, err, ) @@ -614,41 +659,41 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { err = ipcErrorf( ipc.IpcErrorInvalid, "JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d", - tempASecCfg.junkPacketMaxSize, + tempAwg.aSecCfg.junkPacketMaxSize, MaxSegmentSize, ) } - } else if tempASecCfg.junkPacketMaxSize < tempASecCfg.junkPacketMinSize { + } else if tempAwg.aSecCfg.junkPacketMaxSize < tempAwg.aSecCfg.junkPacketMinSize { if err != nil { err = ipcErrorf( ipc.IpcErrorInvalid, "maxSize: %d; should be greater than minSize: %d; %w", - tempASecCfg.junkPacketMaxSize, - tempASecCfg.junkPacketMinSize, + tempAwg.aSecCfg.junkPacketMaxSize, + tempAwg.aSecCfg.junkPacketMinSize, err, ) } else { err = ipcErrorf( ipc.IpcErrorInvalid, "maxSize: %d; should be greater than minSize: %d", - tempASecCfg.junkPacketMaxSize, - tempASecCfg.junkPacketMinSize, + tempAwg.aSecCfg.junkPacketMaxSize, + tempAwg.aSecCfg.junkPacketMinSize, ) } } else { - device.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize + device.awg.aSecCfg.junkPacketMaxSize = tempAwg.aSecCfg.junkPacketMaxSize } - if tempASecCfg.junkPacketMaxSize != 0 { + if tempAwg.aSecCfg.junkPacketMaxSize != 0 { isASecOn = true } - if MessageInitiationSize+tempASecCfg.initPacketJunkSize >= MaxSegmentSize { + if MessageInitiationSize+tempAwg.aSecCfg.initPacketJunkSize >= MaxSegmentSize { if err != nil { err = ipcErrorf( ipc.IpcErrorInvalid, `init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`, - tempASecCfg.initPacketJunkSize, + tempAwg.aSecCfg.initPacketJunkSize, MaxSegmentSize, err, ) @@ -656,24 +701,24 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { err = ipcErrorf( ipc.IpcErrorInvalid, `init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`, - tempASecCfg.initPacketJunkSize, + tempAwg.aSecCfg.initPacketJunkSize, MaxSegmentSize, ) } } else { - device.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize + device.awg.aSecCfg.initPacketJunkSize = tempAwg.aSecCfg.initPacketJunkSize } - if tempASecCfg.initPacketJunkSize != 0 { + if tempAwg.aSecCfg.initPacketJunkSize != 0 { isASecOn = true } - if MessageResponseSize+tempASecCfg.responsePacketJunkSize >= MaxSegmentSize { + if MessageResponseSize+tempAwg.aSecCfg.responsePacketJunkSize >= MaxSegmentSize { if err != nil { err = ipcErrorf( ipc.IpcErrorInvalid, `response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`, - tempASecCfg.responsePacketJunkSize, + tempAwg.aSecCfg.responsePacketJunkSize, MaxSegmentSize, err, ) @@ -681,63 +726,64 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { err = ipcErrorf( ipc.IpcErrorInvalid, `response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`, - tempASecCfg.responsePacketJunkSize, + tempAwg.aSecCfg.responsePacketJunkSize, MaxSegmentSize, ) } } else { - device.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize + device.awg.aSecCfg.responsePacketJunkSize = tempAwg.aSecCfg.responsePacketJunkSize } - if tempASecCfg.responsePacketJunkSize != 0 { + if tempAwg.aSecCfg.responsePacketJunkSize != 0 { isASecOn = true } - if tempASecCfg.initPacketMagicHeader > 4 { + if tempAwg.aSecCfg.initPacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating init_packet_magic_header") - device.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader - MessageInitiationType = device.aSecCfg.initPacketMagicHeader + device.awg.aSecCfg.initPacketMagicHeader = tempAwg.aSecCfg.initPacketMagicHeader + MessageInitiationType = device.awg.aSecCfg.initPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default init type") - MessageInitiationType = 1 + MessageInitiationType = DefaultMessageInitiationType } - if tempASecCfg.responsePacketMagicHeader > 4 { + if tempAwg.aSecCfg.responsePacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating response_packet_magic_header") - device.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader - MessageResponseType = device.aSecCfg.responsePacketMagicHeader + device.awg.aSecCfg.responsePacketMagicHeader = tempAwg.aSecCfg.responsePacketMagicHeader + MessageResponseType = device.awg.aSecCfg.responsePacketMagicHeader } else { device.log.Verbosef("UAPI: Using default response type") - MessageResponseType = 2 + MessageResponseType = DefaultMessageResponseType } - if tempASecCfg.underloadPacketMagicHeader > 4 { + if tempAwg.aSecCfg.underloadPacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating underload_packet_magic_header") - device.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader - MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader + device.awg.aSecCfg.underloadPacketMagicHeader = tempAwg.aSecCfg.underloadPacketMagicHeader + MessageCookieReplyType = device.awg.aSecCfg.underloadPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default underload type") - MessageCookieReplyType = 3 + MessageCookieReplyType = DefaultMessageCookieReplyType } - if tempASecCfg.transportPacketMagicHeader > 4 { + if tempAwg.aSecCfg.transportPacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating transport_packet_magic_header") - device.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader - MessageTransportType = device.aSecCfg.transportPacketMagicHeader + device.awg.aSecCfg.transportPacketMagicHeader = tempAwg.aSecCfg.transportPacketMagicHeader + MessageTransportType = device.awg.aSecCfg.transportPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default transport type") - MessageTransportType = 4 + MessageTransportType = DefaultMessageTransportType } - isSameMap := map[uint32]bool{} - isSameMap[MessageInitiationType] = true - isSameMap[MessageResponseType] = true - isSameMap[MessageCookieReplyType] = true - isSameMap[MessageTransportType] = true + isSameMap := map[uint32]struct{}{ + MessageInitiationType: {}, + MessageResponseType: {}, + MessageCookieReplyType: {}, + MessageTransportType: {}, + } // size will be different if same values if len(isSameMap) != 4 { @@ -763,8 +809,8 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { } } - newInitSize := MessageInitiationSize + device.aSecCfg.initPacketJunkSize - newResponseSize := MessageResponseSize + device.aSecCfg.responsePacketJunkSize + newInitSize := MessageInitiationSize + device.awg.aSecCfg.initPacketJunkSize + newResponseSize := MessageResponseSize + device.awg.aSecCfg.responsePacketJunkSize if newInitSize == newResponseSize { if err != nil { @@ -792,16 +838,23 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { } msgTypeToJunkSize = map[uint32]int{ - MessageInitiationType: device.aSecCfg.initPacketJunkSize, - MessageResponseType: device.aSecCfg.responsePacketJunkSize, + MessageInitiationType: device.awg.aSecCfg.initPacketJunkSize, + MessageResponseType: device.awg.aSecCfg.responsePacketJunkSize, MessageCookieReplyType: 0, MessageTransportType: 0, } } - device.isASecOn.SetTo(isASecOn) - device.junkCreator, err = NewJunkCreator(device) - device.aSecMux.Unlock() + if err := tempAwg.handshakeHandler.Validate(); err == nil { + return ipcErrorf(ipc.IpcErrorInvalid, "handle post config foo validate: %w", err) + } + device.awg.isASecOn.SetTo(isASecOn) + device.awg.junkCreator, err = NewJunkCreator(device) + device.awg.handshakeHandler = tempAwg.handshakeHandler + // TODO: + device.version = VersionAwgSpecialHandshake + + device.awg.aSecMux.Unlock() return err } diff --git a/device/device_test.go b/device/device_test.go index d03610f..7e735e1 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -274,6 +274,7 @@ func TestTwoDevicePing(t *testing.T) { }) } +// Run test with -race=false to avoid the race for setting the default msgTypes 2 times func TestASecurityTwoDevicePing(t *testing.T) { goroutineLeakCheck(t) pair := genTestPair(t, true, true) diff --git a/device/internal/junk-tag/generator.go b/device/internal/junk-tag/generator.go index e664233..8662655 100644 --- a/device/internal/junk-tag/generator.go +++ b/device/internal/junk-tag/generator.go @@ -2,6 +2,7 @@ package junktag import ( crand "crypto/rand" + "encoding/binary" "encoding/hex" "fmt" "strconv" @@ -12,17 +13,23 @@ import ( ) type Generator interface { - Generate() ([]byte, error) + Generate() []byte + Size() int } type newGenerator func(string) (Generator, error) type BytesGenerator struct { value []byte + size int } -func (bg *BytesGenerator) Generate() ([]byte, error) { - return bg.value, nil +func (bg *BytesGenerator) Generate() []byte { + return bg.value +} + +func (bg *BytesGenerator) Size() int { + return bg.size } func newBytesGenerator(param string) (Generator, error) { @@ -37,7 +44,7 @@ func newBytesGenerator(param string) (Generator, error) { return nil, fmt.Errorf("hexToBytes: %w", err) } - return &BytesGenerator{value: hex}, nil + return &BytesGenerator{value: hex, size: len(hex)}, nil } func isHexString(s string) bool { @@ -68,23 +75,30 @@ type RandomPacketGenerator struct { size int } -func (rpg *RandomPacketGenerator) Generate() ([]byte, error) { +func (rpg *RandomPacketGenerator) Generate() []byte { junk := make([]byte, rpg.size) - _, err := rpg.cha8Rand.Read(junk) - return junk, err + rpg.cha8Rand.Read(junk) + return junk +} + +func (rpg *RandomPacketGenerator) Size() int { + return rpg.size } func newRandomPacketGenerator(param string) (Generator, error) { size, err := strconv.Atoi(param) if err != nil { - return nil, fmt.Errorf("randome packet parse int: %w", err) + return nil, fmt.Errorf("random packet parse int: %w", err) + } + + if size > 1000 { + return nil, fmt.Errorf("random packet size must be less than 1000") } - // TODO: add size check buf := make([]byte, 32) _, err = crand.Read(buf) if err != nil { - return nil, fmt.Errorf("randome packet crand read: %w", err) + return nil, fmt.Errorf("random packet crand read: %w", err) } return &RandomPacketGenerator{cha8Rand: v2.NewChaCha8([32]byte(buf)), size: size}, nil @@ -93,8 +107,14 @@ func newRandomPacketGenerator(param string) (Generator, error) { type TimestampGenerator struct { } -func (tg *TimestampGenerator) Generate() ([]byte, error) { - return time.Now().MarshalBinary() +func (tg *TimestampGenerator) Generate() []byte { + buf := make([]byte, 8) + binary.BigEndian.PutUint64(buf, uint64(time.Now().Unix())) + return buf +} + +func (tg *TimestampGenerator) Size() int { + return 8 } func newTimestampGenerator(param string) (Generator, error) { @@ -104,3 +124,29 @@ func newTimestampGenerator(param string) (Generator, error) { return &TimestampGenerator{}, nil } + +type WaitTimeoutGenerator struct { + waitTimeout time.Duration +} + +func (wtg *WaitTimeoutGenerator) Generate() []byte { + time.Sleep(wtg.waitTimeout) + return []byte{} +} + +func (wtg *WaitTimeoutGenerator) Size() int { + return 0 +} + +func newWaitTimeoutGenerator(param string) (Generator, error) { + size, err := strconv.Atoi(param) + if err != nil { + return nil, fmt.Errorf("timeout parse int: %w", err) + } + + if size > 5000 { + return nil, fmt.Errorf("timeout size must be less than 5000ms") + } + + return &WaitTimeoutGenerator{}, nil +} diff --git a/device/internal/junk-tag/generator_test.go b/device/internal/junk-tag/generator_test.go index 44bf8a1..a80e048 100644 --- a/device/internal/junk-tag/generator_test.go +++ b/device/internal/junk-tag/generator_test.go @@ -66,7 +66,7 @@ func Test_newBytesGenerator(t *testing.T) { require.Nil(t, err) require.NotNil(t, got) - gotValues, _ := got.Generate() + gotValues := got.Generate() require.Equal(t, tt.want, gotValues) }) } @@ -95,6 +95,13 @@ func Test_newRandomPacketGenerator(t *testing.T) { }, wantErr: fmt.Errorf("parse int"), }, + { + name: "too large", + args: args{ + param: "1001", + }, + wantErr: fmt.Errorf("random packet size must be less than 1000"), + }, { name: "valid", args: args{ @@ -113,11 +120,9 @@ func Test_newRandomPacketGenerator(t *testing.T) { require.Nil(t, err) require.NotNil(t, got) - first, err := got.Generate() - require.Nil(t, err) + first := got.Generate() - second, err := got.Generate() - require.Nil(t, err) + second := got.Generate() require.NotEqual(t, first, second) }) } diff --git a/device/internal/junk-tag/parser.go b/device/internal/junk-tag/parser.go index 1d20f85..8230c46 100644 --- a/device/internal/junk-tag/parser.go +++ b/device/internal/junk-tag/parser.go @@ -2,36 +2,39 @@ package junktag import ( "fmt" + "maps" "regexp" "strings" ) -type Enum string +type EnumTag string const ( - EnumBytes Enum = "b" - EnumCounter Enum = "c" - EnumTimestamp Enum = "t" - EnumRandomBytes Enum = "r" - EnumWaitTimeout Enum = "wt" - EnumWaitResponse Enum = "wr" + BytesEnumTag EnumTag = "b" + CounterEnumTag EnumTag = "c" + TimestampEnumTag EnumTag = "t" + RandomBytesEnumTag EnumTag = "r" + WaitTimeoutEnumTag EnumTag = "wt" + WaitResponseEnumTag EnumTag = "wr" ) -var validEnum = map[Enum]newGenerator{ - EnumBytes: newBytesGenerator, - EnumCounter: func(s string) (Generator, error) { return &BytesGenerator{}, nil }, - EnumTimestamp: newTimestampGenerator, - EnumRandomBytes: newRandomPacketGenerator, - EnumWaitTimeout: func(s string) (Generator, error) { return &BytesGenerator{}, nil }, - EnumWaitResponse: func(s string) (Generator, error) { return &BytesGenerator{}, nil }, +var generatorCreator = map[EnumTag]newGenerator{ + BytesEnumTag: newBytesGenerator, + CounterEnumTag: func(s string) (Generator, error) { return &BytesGenerator{}, nil }, + TimestampEnumTag: newTimestampGenerator, + RandomBytesEnumTag: newRandomPacketGenerator, + WaitTimeoutEnumTag: newWaitTimeoutGenerator, + WaitResponseEnumTag: func(s string) (Generator, error) { return &BytesGenerator{}, nil }, } -type Foo struct { - x []Generator +// helper map to determine enumTags are unique +var uniqueTags = map[EnumTag]bool{ + CounterEnumTag: false, + TimestampEnumTag: false, } type Tag struct { - Name Enum + Name EnumTag Param string } @@ -41,7 +44,7 @@ func parseTag(input string) (Tag, error) { match := re.FindStringSubmatch(input) tag := Tag{ - Name: Enum(match[1]), + Name: EnumTag(match[1]), } if len(match) > 2 && match[2] != "" { tag.Param = strings.TrimSpace(match[2]) @@ -50,35 +53,43 @@ func parseTag(input string) (Tag, error) { return tag, nil } -func Parse(input string) (Foo, error) { +// TODO: pointernes +func Parse(name, input string) (TaggedJunkGenerator, error) { inputSlice := strings.Split(input, "<") - fmt.Printf("%v\n", inputSlice) if len(inputSlice) <= 1 { - return Foo{}, fmt.Errorf("empty input: %s", input) + return TaggedJunkGenerator{}, fmt.Errorf("empty input: %s", input) } + uniqueTagCheck := make(map[EnumTag]bool, len(uniqueTags)) + maps.Copy(uniqueTagCheck, uniqueTags) + // skip byproduct of split inputSlice = inputSlice[1:] - rv := Foo{x: make([]Generator, 0, len(inputSlice))} - + rv := newTagedJunkGenerator(name, len(inputSlice)) for _, inputParam := range inputSlice { if len(inputParam) <= 1 { - return Foo{}, fmt.Errorf("empty tag in input: %s", inputSlice) + return TaggedJunkGenerator{}, fmt.Errorf("empty tag in input: %s", inputSlice) } else if strings.Count(inputParam, ">") != 1 { - return Foo{}, fmt.Errorf("ill formated input: %s", input) + return TaggedJunkGenerator{}, fmt.Errorf("ill formated input: %s", input) } tag, _ := parseTag(inputParam) - fmt.Printf("Tag: %s, Param: %s\n", tag.Name, tag.Param) - gen, ok := validEnum[tag.Name] + creator, ok := generatorCreator[tag.Name] if !ok { - return Foo{}, fmt.Errorf("invalid tag: %s", tag.Name) + return TaggedJunkGenerator{}, fmt.Errorf("invalid tag: %s", tag.Name) } - generator, err := gen(tag.Param) + if present, ok := uniqueTagCheck[tag.Name]; ok { + if present { + return TaggedJunkGenerator{}, fmt.Errorf("tag %s needs to be unique", tag.Name) + } + uniqueTagCheck[tag.Name] = true + } + generator, err := creator(tag.Param) if err != nil { - return Foo{}, fmt.Errorf("gen: %w", err) + return TaggedJunkGenerator{}, fmt.Errorf("gen: %w", err) } - rv.x = append(rv.x, generator) + + rv.append(generator) } return rv, nil diff --git a/device/internal/junk-tag/parser_test.go b/device/internal/junk-tag/parser_test.go index d9b9642..dfd3399 100644 --- a/device/internal/junk-tag/parser_test.go +++ b/device/internal/junk-tag/parser_test.go @@ -41,6 +41,16 @@ func TestParse(t *testing.T) { args: args{input: ""}, wantErr: fmt.Errorf("invalid tag"), }, + { + name: "counter uniqueness violation", + args: args{input: ""}, + wantErr: fmt.Errorf("parse tag needs to be unique"), + }, + { + name: "timestamp uniqueness violation", + args: args{input: ""}, + wantErr: fmt.Errorf("parse tag needs to be unique"), + }, { name: "valid", args: args{input: ""}, diff --git a/device/internal/junk-tag/special_handshake_handler.go b/device/internal/junk-tag/special_handshake_handler.go new file mode 100644 index 0000000..caad83f --- /dev/null +++ b/device/internal/junk-tag/special_handshake_handler.go @@ -0,0 +1,48 @@ +package junktag + +import ( + "errors" + "time" +) + +type SpecialHandshakeHandler struct { + SpecialJunk TaggedJunkGeneratorHandler + ControlledJunk TaggedJunkGeneratorHandler + + nextItime time.Time + ITimeout time.Duration // seconds + // TODO: maybe atomic? + PacketCounter uint64 +} + +func (handler *SpecialHandshakeHandler) Validate() error { + var errs []error + if err := handler.SpecialJunk.Validate(); err != nil { + errs = append(errs, err) + } + if err := handler.ControlledJunk.Validate(); err != nil { + errs = append(errs, err) + } + return errors.Join(errs...) +} + +func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte { + // TODO: distiungish between first and the rest of the packets + if !handler.isTimeToSendSpecial() { + return nil + } + + rv := handler.SpecialJunk.Generate() + + handler.nextItime = time.Now().Add(time.Duration(handler.ITimeout)) + + return rv +} + +func (handler *SpecialHandshakeHandler) isTimeToSendSpecial() bool { + return time.Now().After(handler.nextItime) +} + +func (handler *SpecialHandshakeHandler) PrepareControlledJunk() [][]byte { + return handler.ControlledJunk.Generate() +} diff --git a/device/internal/junk-tag/tagged_junk_generator.go b/device/internal/junk-tag/tagged_junk_generator.go new file mode 100644 index 0000000..a89edac --- /dev/null +++ b/device/internal/junk-tag/tagged_junk_generator.go @@ -0,0 +1,42 @@ +package junktag + +import ( + "fmt" + "strconv" +) + +type TaggedJunkGenerator struct { + name string + packetSize int + generators []Generator +} + +func newTagedJunkGenerator(name string, size int) TaggedJunkGenerator { + return TaggedJunkGenerator{name: name, generators: make([]Generator, size)} +} + +func (tg *TaggedJunkGenerator) append(generator Generator) { + tg.generators = append(tg.generators, generator) + tg.packetSize += generator.Size() +} + +func (tg *TaggedJunkGenerator) generate() []byte { + packet := make([]byte, 0, tg.packetSize) + for _, generator := range tg.generators { + packet = append(packet, generator.Generate()...) + } + + return packet +} + +func (t *TaggedJunkGenerator) nameIndex() (int, error) { + if len(t.name) != 2 { + return 0, fmt.Errorf("name must be 2 character long: %s", t.name) + } + + index, err := strconv.Atoi(t.name[1:2]) + if err != nil { + return 0, fmt.Errorf("name should be 2 char long: %w", err) + } + return index, nil +} diff --git a/device/internal/junk-tag/tagged_junk_generator_handler.go b/device/internal/junk-tag/tagged_junk_generator_handler.go new file mode 100644 index 0000000..d5feb61 --- /dev/null +++ b/device/internal/junk-tag/tagged_junk_generator_handler.go @@ -0,0 +1,43 @@ +package junktag + +import "fmt" + +type TaggedJunkGeneratorHandler struct { + generators []TaggedJunkGenerator + length int +} + +func (handler *TaggedJunkGeneratorHandler) AppendGenerator(generators TaggedJunkGenerator) { + handler.generators = append(handler.generators, generators) + handler.length++ +} + +// validate that packets were defined consecutively +func (handler *TaggedJunkGeneratorHandler) Validate() error { + seen := make([]bool, len(handler.generators)) + for _, generator := range handler.generators { + if index, err := generator.nameIndex(); err != nil { + return fmt.Errorf("name index: %w", err) + } else { + seen[index-1] = true + } + } + + for _, found := range seen { + if !found { + return fmt.Errorf("junk packet index should be consecutive") + } + } + + return nil +} + +func (handler *TaggedJunkGeneratorHandler) Generate() [][]byte { + var rv = make([][]byte, handler.length) + for i, generator := range handler.generators { + rv[i] = make([]byte, generator.packetSize) + copy(rv[i], generator.generate()) + } + + return rv +} diff --git a/device/junk_creator.go b/device/junk_creator.go index 3a2d3b4..1df4612 100644 --- a/device/junk_creator.go +++ b/device/junk_creator.go @@ -22,47 +22,48 @@ func NewJunkCreator(d *Device) (junkCreator, error) { } // Should be called with aSecMux RLocked -func (jc *junkCreator) createJunkPackets() ([][]byte, error) { - if jc.device.aSecCfg.junkPacketCount == 0 { - return nil, nil +func (jc *junkCreator) createJunkPackets(junks *[][]byte) error { + if jc.device.awg.aSecCfg.junkPacketCount == 0 { + return nil } - junks := make([][]byte, 0, jc.device.aSecCfg.junkPacketCount) - for i := 0; i < jc.device.aSecCfg.junkPacketCount; i++ { + *junks = make([][]byte, len(*junks)+jc.device.awg.aSecCfg.junkPacketCount) + for i := range jc.device.awg.aSecCfg.junkPacketCount { packetSize := jc.randomPacketSize() junk, err := jc.randomJunkWithSize(packetSize) if err != nil { - return nil, fmt.Errorf("Failed to create junk packet: %v", err) + return fmt.Errorf("create junk packet: %v", err) } - junks = append(junks, junk) + (*junks)[i] = junk } - return junks, nil + return nil } // Should be called with aSecMux RLocked func (jc *junkCreator) randomPacketSize() int { return int( jc.cha8Rand.Uint64()%uint64( - jc.device.aSecCfg.junkPacketMaxSize-jc.device.aSecCfg.junkPacketMinSize, + jc.device.awg.aSecCfg.junkPacketMaxSize-jc.device.awg.aSecCfg.junkPacketMinSize, ), - ) + jc.device.aSecCfg.junkPacketMinSize + ) + jc.device.awg.aSecCfg.junkPacketMinSize } // Should be called with aSecMux RLocked func (jc *junkCreator) appendJunk(writer *bytes.Buffer, size int) error { headerJunk, err := jc.randomJunkWithSize(size) if err != nil { - return fmt.Errorf("failed to create header junk: %v", err) + return fmt.Errorf("create header junk: %v", err) } _, err = writer.Write(headerJunk) if err != nil { - return fmt.Errorf("failed to write header junk: %v", err) + return fmt.Errorf("write header junk: %v", err) } return nil } // Should be called with aSecMux RLocked func (jc *junkCreator) randomJunkWithSize(size int) ([]byte, error) { + // TODO: use a memory pool to allocate junk := make([]byte, size) _, err := jc.cha8Rand.Read(junk) return junk, err diff --git a/device/junk_creator_test.go b/device/junk_creator_test.go index d3cf2b3..33a7a23 100644 --- a/device/junk_creator_test.go +++ b/device/junk_creator_test.go @@ -91,13 +91,13 @@ func Test_junkCreator_randomPacketSize(t *testing.T) { } for range [30]struct{}{} { t.Run("", func(t *testing.T) { - if got := jc.randomPacketSize(); jc.device.aSecCfg.junkPacketMinSize > got || - got > jc.device.aSecCfg.junkPacketMaxSize { + if got := jc.randomPacketSize(); jc.device.awg.aSecCfg.junkPacketMinSize > got || + got > jc.device.awg.aSecCfg.junkPacketMaxSize { t.Errorf( "junkCreator.randomPacketSize() = %v, not between range [%v,%v]", got, - jc.device.aSecCfg.junkPacketMinSize, - jc.device.aSecCfg.junkPacketMaxSize, + jc.device.awg.aSecCfg.junkPacketMinSize, + jc.device.awg.aSecCfg.junkPacketMaxSize, ) } }) diff --git a/device/noise-protocol.go b/device/noise-protocol.go index 1289249..d774904 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -52,11 +52,18 @@ const ( WGLabelCookie = "cookie--" ) +const ( + DefaultMessageInitiationType uint32 = 1 + DefaultMessageResponseType uint32 = 2 + DefaultMessageCookieReplyType uint32 = 3 + DefaultMessageTransportType uint32 = 4 +) + var ( - MessageInitiationType uint32 = 1 - MessageResponseType uint32 = 2 - MessageCookieReplyType uint32 = 3 - MessageTransportType uint32 = 4 + MessageInitiationType uint32 = DefaultMessageInitiationType + MessageResponseType uint32 = DefaultMessageResponseType + MessageCookieReplyType uint32 = DefaultMessageCookieReplyType + MessageTransportType uint32 = DefaultMessageTransportType ) const ( @@ -197,12 +204,12 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e handshake.mixHash(handshake.remoteStatic[:]) - device.aSecMux.RLock() + device.awg.aSecMux.RLock() msg := MessageInitiation{ Type: MessageInitiationType, Ephemeral: handshake.localEphemeral.publicKey(), } - device.aSecMux.RUnlock() + device.awg.aSecMux.RUnlock() handshake.mixKey(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:]) @@ -256,12 +263,12 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { chainKey [blake2s.Size]byte ) - device.aSecMux.RLock() + device.awg.aSecMux.RLock() if msg.Type != MessageInitiationType { - device.aSecMux.RUnlock() + device.awg.aSecMux.RUnlock() return nil } - device.aSecMux.RUnlock() + device.awg.aSecMux.RUnlock() device.staticIdentity.RLock() defer device.staticIdentity.RUnlock() @@ -376,9 +383,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error } var msg MessageResponse - device.aSecMux.RLock() + device.awg.aSecMux.RLock() msg.Type = MessageResponseType - device.aSecMux.RUnlock() + device.awg.aSecMux.RUnlock() msg.Sender = handshake.localIndex msg.Receiver = handshake.remoteIndex @@ -428,12 +435,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error } func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { - device.aSecMux.RLock() + device.awg.aSecMux.RLock() if msg.Type != MessageResponseType { - device.aSecMux.RUnlock() + device.awg.aSecMux.RUnlock() return nil } - device.aSecMux.RUnlock() + device.awg.aSecMux.RUnlock() // lookup handshake by receiver diff --git a/device/peer.go b/device/peer.go index 5bc8ca4..bdc52fc 100644 --- a/device/peer.go +++ b/device/peer.go @@ -137,6 +137,7 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error { if err == nil { var totalLen uint64 for _, b := range buffers { + peer.device.awg.foo.PacketCounter++ totalLen += uint64(len(b)) } peer.txBytes.Add(totalLen) diff --git a/device/receive.go b/device/receive.go index 66c1a32..f235a33 100644 --- a/device/receive.go +++ b/device/receive.go @@ -129,7 +129,7 @@ func (device *Device) RoutineReceiveIncoming( } deathSpiral = 0 - device.aSecMux.RLock() + device.awg.aSecMux.RLock() // handle each packet in the batch for i, size := range sizes[:count] { if size < MinMessageSize { @@ -149,13 +149,14 @@ func (device *Device) RoutineReceiveIncoming( if msgType == assumedMsgType { packet = packet[junkSize:] } else { - device.log.Verbosef("Transport packet lined up with another msg type") + device.log.Verbosef("transport packet lined up with another msg type") msgType = binary.LittleEndian.Uint32(packet[:4]) } } else { msgType = binary.LittleEndian.Uint32(packet[:4]) if msgType != MessageTransportType { - device.log.Verbosef("ASec: Received message with unknown type") + // probably a junk packet + device.log.Verbosef("aSec: Received message with unknown type: %d", msgType) continue } } @@ -245,7 +246,7 @@ func (device *Device) RoutineReceiveIncoming( default: } } - device.aSecMux.RUnlock() + device.awg.aSecMux.RUnlock() for peer, elemsContainer := range elemsByPeer { if peer.isRunning.Load() { peer.queue.inbound.c <- elemsContainer @@ -304,7 +305,7 @@ func (device *Device) RoutineHandshake(id int) { for elem := range device.queue.handshake.c { - device.aSecMux.RLock() + device.awg.aSecMux.RLock() // handle cookie fields and ratelimiting @@ -456,7 +457,7 @@ func (device *Device) RoutineHandshake(id int) { peer.SendKeepalive() } skip: - device.aSecMux.RUnlock() + device.awg.aSecMux.RUnlock() device.PutMessageBuffer(elem.buffer) } } diff --git a/device/send.go b/device/send.go index 7eca099..f49b0e1 100644 --- a/device/send.go +++ b/device/send.go @@ -126,10 +126,21 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { var sendBuffer [][]byte // so only packet processed for cookie generation var junkedHeader []byte - if peer.device.isAdvancedSecurityOn() { - peer.device.aSecMux.RLock() - junks, err := peer.device.junkCreator.createJunkPackets() - peer.device.aSecMux.RUnlock() + + if peer.device.version >= VersionAwg { + junks := [][]byte{} + if peer.device.version == VersionAwgSpecialHandshake { + peer.device.awg.aSecMux.RLock() + // set junks depending on packet type + junks = peer.device.awg.handshakeHandler.GenerateSpecialJunk() + if junks == nil { + junks = peer.device.awg.handshakeHandler.GenerateSpecialJunk() + } + peer.device.awg.aSecMux.RUnlock() + } + peer.device.awg.aSecMux.RLock() + err := peer.device.awg.junkCreator.createJunkPackets(&junks) + peer.device.awg.aSecMux.RUnlock() if err != nil { peer.device.log.Errorf("%v - %v", peer, err) @@ -145,19 +156,19 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { } } - peer.device.aSecMux.RLock() - if peer.device.aSecCfg.initPacketJunkSize != 0 { - buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize) + peer.device.awg.aSecMux.RLock() + if peer.device.awg.aSecCfg.initPacketJunkSize != 0 { + buf := make([]byte, 0, peer.device.awg.aSecCfg.initPacketJunkSize) writer := bytes.NewBuffer(buf[:0]) - err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize) + err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.initPacketJunkSize) if err != nil { peer.device.log.Errorf("%v - %v", peer, err) - peer.device.aSecMux.RUnlock() + peer.device.awg.aSecMux.RUnlock() return err } junkedHeader = writer.Bytes() } - peer.device.aSecMux.RUnlock() + peer.device.awg.aSecMux.RUnlock() } var buf [MessageInitiationSize]byte @@ -195,19 +206,19 @@ func (peer *Peer) SendHandshakeResponse() error { } var junkedHeader []byte if peer.device.isAdvancedSecurityOn() { - peer.device.aSecMux.RLock() - if peer.device.aSecCfg.responsePacketJunkSize != 0 { - buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize) + peer.device.awg.aSecMux.RLock() + if peer.device.awg.aSecCfg.responsePacketJunkSize != 0 { + buf := make([]byte, 0, peer.device.awg.aSecCfg.responsePacketJunkSize) writer := bytes.NewBuffer(buf[:0]) - err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize) + err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.responsePacketJunkSize) if err != nil { - peer.device.aSecMux.RUnlock() + peer.device.awg.aSecMux.RUnlock() peer.device.log.Errorf("%v - %v", peer, err) return err } junkedHeader = writer.Bytes() } - peer.device.aSecMux.RUnlock() + peer.device.awg.aSecMux.RUnlock() } var buf [MessageResponseSize]byte writer := bytes.NewBuffer(buf[:0]) diff --git a/device/uapi.go b/device/uapi.go index 777bdda..7ae4b1c 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -18,6 +18,7 @@ import ( "sync" "time" + junktag "github.com/amnezia-vpn/amneziawg-go/device/internal/junk-tag" "github.com/amnezia-vpn/amneziawg-go/ipc" ) @@ -98,32 +99,32 @@ func (device *Device) IpcGetOperation(w io.Writer) error { } if device.isAdvancedSecurityOn() { - if device.aSecCfg.junkPacketCount != 0 { - sendf("jc=%d", device.aSecCfg.junkPacketCount) + if device.awg.aSecCfg.junkPacketCount != 0 { + sendf("jc=%d", device.awg.aSecCfg.junkPacketCount) } - if device.aSecCfg.junkPacketMinSize != 0 { - sendf("jmin=%d", device.aSecCfg.junkPacketMinSize) + if device.awg.aSecCfg.junkPacketMinSize != 0 { + sendf("jmin=%d", device.awg.aSecCfg.junkPacketMinSize) } - if device.aSecCfg.junkPacketMaxSize != 0 { - sendf("jmax=%d", device.aSecCfg.junkPacketMaxSize) + if device.awg.aSecCfg.junkPacketMaxSize != 0 { + sendf("jmax=%d", device.awg.aSecCfg.junkPacketMaxSize) } - if device.aSecCfg.initPacketJunkSize != 0 { - sendf("s1=%d", device.aSecCfg.initPacketJunkSize) + if device.awg.aSecCfg.initPacketJunkSize != 0 { + sendf("s1=%d", device.awg.aSecCfg.initPacketJunkSize) } - if device.aSecCfg.responsePacketJunkSize != 0 { - sendf("s2=%d", device.aSecCfg.responsePacketJunkSize) + if device.awg.aSecCfg.responsePacketJunkSize != 0 { + sendf("s2=%d", device.awg.aSecCfg.responsePacketJunkSize) } - if device.aSecCfg.initPacketMagicHeader != 0 { - sendf("h1=%d", device.aSecCfg.initPacketMagicHeader) + if device.awg.aSecCfg.initPacketMagicHeader != 0 { + sendf("h1=%d", device.awg.aSecCfg.initPacketMagicHeader) } - if device.aSecCfg.responsePacketMagicHeader != 0 { - sendf("h2=%d", device.aSecCfg.responsePacketMagicHeader) + if device.awg.aSecCfg.responsePacketMagicHeader != 0 { + sendf("h2=%d", device.awg.aSecCfg.responsePacketMagicHeader) } - if device.aSecCfg.underloadPacketMagicHeader != 0 { - sendf("h3=%d", device.aSecCfg.underloadPacketMagicHeader) + if device.awg.aSecCfg.underloadPacketMagicHeader != 0 { + sendf("h3=%d", device.awg.aSecCfg.underloadPacketMagicHeader) } - if device.aSecCfg.transportPacketMagicHeader != 0 { - sendf("h4=%d", device.aSecCfg.transportPacketMagicHeader) + if device.awg.aSecCfg.transportPacketMagicHeader != 0 { + sendf("h4=%d", device.awg.aSecCfg.transportPacketMagicHeader) } } @@ -180,13 +181,13 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { peer := new(ipcSetPeer) deviceConfig := true - tempASecCfg := aSecCfgType{} + tempAwg := awg{} scanner := bufio.NewScanner(r) for scanner.Scan() { line := scanner.Text() if line == "" { // Blank line means terminate operation. - err := device.handlePostConfig(&tempASecCfg) + err := device.handlePostConfig(&tempAwg) if err != nil { return err } @@ -217,7 +218,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { var err error if deviceConfig { - err = device.handleDeviceLine(key, value, &tempASecCfg) + err = device.handleDeviceLine(key, value, &tempAwg) } else { err = device.handlePeerLine(peer, key, value) } @@ -225,7 +226,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { return err } } - err = device.handlePostConfig(&tempASecCfg) + err = device.handlePostConfig(&tempAwg) if err != nil { return err } @@ -237,7 +238,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { return nil } -func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgType) error { +func (device *Device) handleDeviceLine(key, value string, tempAwg *awg) error { switch key { case "private_key": var sk NoisePrivateKey @@ -286,80 +287,113 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy case "jc": junkPacketCount, err := strconv.Atoi(value) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_count %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_count %w", err) } device.log.Verbosef("UAPI: Updating junk_packet_count") - tempASecCfg.junkPacketCount = junkPacketCount - tempASecCfg.isSet = true + tempAwg.aSecCfg.junkPacketCount = junkPacketCount + tempAwg.aSecCfg.isSet = true case "jmin": junkPacketMinSize, err := strconv.Atoi(value) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_min_size %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_min_size %w", err) } device.log.Verbosef("UAPI: Updating junk_packet_min_size") - tempASecCfg.junkPacketMinSize = junkPacketMinSize - tempASecCfg.isSet = true + tempAwg.aSecCfg.junkPacketMinSize = junkPacketMinSize + tempAwg.aSecCfg.isSet = true case "jmax": junkPacketMaxSize, err := strconv.Atoi(value) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_max_size %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_max_size %w", err) } device.log.Verbosef("UAPI: Updating junk_packet_max_size") - tempASecCfg.junkPacketMaxSize = junkPacketMaxSize - tempASecCfg.isSet = true + tempAwg.aSecCfg.junkPacketMaxSize = junkPacketMaxSize + tempAwg.aSecCfg.isSet = true case "s1": initPacketJunkSize, err := strconv.Atoi(value) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_junk_size %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_junk_size %w", err) } device.log.Verbosef("UAPI: Updating init_packet_junk_size") - tempASecCfg.initPacketJunkSize = initPacketJunkSize - tempASecCfg.isSet = true + tempAwg.aSecCfg.initPacketJunkSize = initPacketJunkSize + tempAwg.aSecCfg.isSet = true case "s2": responsePacketJunkSize, err := strconv.Atoi(value) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_junk_size %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_junk_size %w", err) } device.log.Verbosef("UAPI: Updating response_packet_junk_size") - tempASecCfg.responsePacketJunkSize = responsePacketJunkSize - tempASecCfg.isSet = true + tempAwg.aSecCfg.responsePacketJunkSize = responsePacketJunkSize + tempAwg.aSecCfg.isSet = true case "h1": initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_magic_header %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_magic_header %w", err) } - tempASecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader) - tempASecCfg.isSet = true + tempAwg.aSecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader) + tempAwg.aSecCfg.isSet = true case "h2": responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_magic_header %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_magic_header %w", err) } - tempASecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader) - tempASecCfg.isSet = true + tempAwg.aSecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader) + tempAwg.aSecCfg.isSet = true case "h3": underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse underload_packet_magic_header %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse underload_packet_magic_header %w", err) } - tempASecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader) - tempASecCfg.isSet = true + tempAwg.aSecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader) + tempAwg.aSecCfg.isSet = true case "h4": transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse transport_packet_magic_header %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_magic_header %w", err) + } + tempAwg.aSecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader) + tempAwg.aSecCfg.isSet = true + case "i1", "i2", "i3", "i4", "i5": + if len(value) == 0 { + return ipcErrorf(ipc.IpcErrorInvalid, "%s should be non null", key) } - tempASecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader) - tempASecCfg.isSet = true + generators, err := junktag.Parse(key, value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err) + } + device.log.Verbosef("UAPI: Updating %s", key) + tempAwg.handshakeHandler.SpecialJunk.AppendGenerator(generators) + tempAwg.aSecCfg.isSet = true + case "j1", "j2", "j3": + if len(value) == 0 { + return ipcErrorf(ipc.IpcErrorInvalid, "%s should be non null", key) + } + + generators, err := junktag.Parse(key, value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err) + } + device.log.Verbosef("UAPI: Updating %s", key) + + tempAwg.handshakeHandler.ControlledJunk.AppendGenerator(generators) + tempAwg.aSecCfg.isSet = true + case "itime": + itime, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "parse itime %w", err) + } + device.log.Verbosef("UAPI: Updating itime %s", itime) + + tempAwg.handshakeHandler.ITimeout = time.Duration(itime) + tempAwg.aSecCfg.isSet = true default: return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key) }