From a1d8adca4889e264b05879cb430d9197fda1ffcc Mon Sep 17 00:00:00 2001 From: Mark Puha
"}, wantErr: fmt.Errorf("invalid tag"), }, + { + name: "counter uniqueness violation", + args: args{input: ""}, + wantErr: fmt.Errorf("parse tag needs to be unique"), + }, + { + name: "timestamp uniqueness violation", + args: args{input: " "}, + wantErr: fmt.Errorf("parse tag needs to be unique"), + }, { name: "valid", args: args{input: " "}, diff --git a/device/internal/junk-tag/special_handshake_handler.go b/device/internal/junk-tag/special_handshake_handler.go new file mode 100644 index 0000000..caad83f --- /dev/null +++ b/device/internal/junk-tag/special_handshake_handler.go @@ -0,0 +1,48 @@ +package junktag + +import ( + "errors" + "time" +) + +type SpecialHandshakeHandler struct { + SpecialJunk TaggedJunkGeneratorHandler + ControlledJunk TaggedJunkGeneratorHandler + + nextItime time.Time + ITimeout time.Duration // seconds + // TODO: maybe atomic? + PacketCounter uint64 +} + +func (handler *SpecialHandshakeHandler) Validate() error { + var errs []error + if err := handler.SpecialJunk.Validate(); err != nil { + errs = append(errs, err) + } + if err := handler.ControlledJunk.Validate(); err != nil { + errs = append(errs, err) + } + return errors.Join(errs...) +} + +func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte { + // TODO: distiungish between first and the rest of the packets + if !handler.isTimeToSendSpecial() { + return nil + } + + rv := handler.SpecialJunk.Generate() + + handler.nextItime = time.Now().Add(time.Duration(handler.ITimeout)) + + return rv +} + +func (handler *SpecialHandshakeHandler) isTimeToSendSpecial() bool { + return time.Now().After(handler.nextItime) +} + +func (handler *SpecialHandshakeHandler) PrepareControlledJunk() [][]byte { + return handler.ControlledJunk.Generate() +} diff --git a/device/internal/junk-tag/tagged_junk_generator.go b/device/internal/junk-tag/tagged_junk_generator.go new file mode 100644 index 0000000..a89edac --- /dev/null +++ b/device/internal/junk-tag/tagged_junk_generator.go @@ -0,0 +1,42 @@ +package junktag + +import ( + "fmt" + "strconv" +) + +type TaggedJunkGenerator struct { + name string + packetSize int + generators []Generator +} + +func newTagedJunkGenerator(name string, size int) TaggedJunkGenerator { + return TaggedJunkGenerator{name: name, generators: make([]Generator, size)} +} + +func (tg *TaggedJunkGenerator) append(generator Generator) { + tg.generators = append(tg.generators, generator) + tg.packetSize += generator.Size() +} + +func (tg *TaggedJunkGenerator) generate() []byte { + packet := make([]byte, 0, tg.packetSize) + for _, generator := range tg.generators { + packet = append(packet, generator.Generate()...) + } + + return packet +} + +func (t *TaggedJunkGenerator) nameIndex() (int, error) { + if len(t.name) != 2 { + return 0, fmt.Errorf("name must be 2 character long: %s", t.name) + } + + index, err := strconv.Atoi(t.name[1:2]) + if err != nil { + return 0, fmt.Errorf("name should be 2 char long: %w", err) + } + return index, nil +} diff --git a/device/internal/junk-tag/tagged_junk_generator_handler.go b/device/internal/junk-tag/tagged_junk_generator_handler.go new file mode 100644 index 0000000..d5feb61 --- /dev/null +++ b/device/internal/junk-tag/tagged_junk_generator_handler.go @@ -0,0 +1,43 @@ +package junktag + +import "fmt" + +type TaggedJunkGeneratorHandler struct { + generators []TaggedJunkGenerator + length int +} + +func (handler *TaggedJunkGeneratorHandler) AppendGenerator(generators TaggedJunkGenerator) { + handler.generators = append(handler.generators, generators) + handler.length++ +} + +// validate that packets were defined consecutively +func (handler *TaggedJunkGeneratorHandler) Validate() error { + seen := make([]bool, len(handler.generators)) + for _, generator := range handler.generators { + if index, err := generator.nameIndex(); err != nil { + return fmt.Errorf("name index: %w", err) + } else { + seen[index-1] = true + } + } + + for _, found := range seen { + if !found { + return fmt.Errorf("junk packet index should be consecutive") + } + } + + return nil +} + +func (handler *TaggedJunkGeneratorHandler) Generate() [][]byte { + var rv = make([][]byte, handler.length) + for i, generator := range handler.generators { + rv[i] = make([]byte, generator.packetSize) + copy(rv[i], generator.generate()) + } + + return rv +} diff --git a/device/junk_creator.go b/device/junk_creator.go index 3a2d3b4..1df4612 100644 --- a/device/junk_creator.go +++ b/device/junk_creator.go @@ -22,47 +22,48 @@ func NewJunkCreator(d *Device) (junkCreator, error) { } // Should be called with aSecMux RLocked -func (jc *junkCreator) createJunkPackets() ([][]byte, error) { - if jc.device.aSecCfg.junkPacketCount == 0 { - return nil, nil +func (jc *junkCreator) createJunkPackets(junks *[][]byte) error { + if jc.device.awg.aSecCfg.junkPacketCount == 0 { + return nil } - junks := make([][]byte, 0, jc.device.aSecCfg.junkPacketCount) - for i := 0; i < jc.device.aSecCfg.junkPacketCount; i++ { + *junks = make([][]byte, len(*junks)+jc.device.awg.aSecCfg.junkPacketCount) + for i := range jc.device.awg.aSecCfg.junkPacketCount { packetSize := jc.randomPacketSize() junk, err := jc.randomJunkWithSize(packetSize) if err != nil { - return nil, fmt.Errorf("Failed to create junk packet: %v", err) + return fmt.Errorf("create junk packet: %v", err) } - junks = append(junks, junk) + (*junks)[i] = junk } - return junks, nil + return nil } // Should be called with aSecMux RLocked func (jc *junkCreator) randomPacketSize() int { return int( jc.cha8Rand.Uint64()%uint64( - jc.device.aSecCfg.junkPacketMaxSize-jc.device.aSecCfg.junkPacketMinSize, + jc.device.awg.aSecCfg.junkPacketMaxSize-jc.device.awg.aSecCfg.junkPacketMinSize, ), - ) + jc.device.aSecCfg.junkPacketMinSize + ) + jc.device.awg.aSecCfg.junkPacketMinSize } // Should be called with aSecMux RLocked func (jc *junkCreator) appendJunk(writer *bytes.Buffer, size int) error { headerJunk, err := jc.randomJunkWithSize(size) if err != nil { - return fmt.Errorf("failed to create header junk: %v", err) + return fmt.Errorf("create header junk: %v", err) } _, err = writer.Write(headerJunk) if err != nil { - return fmt.Errorf("failed to write header junk: %v", err) + return fmt.Errorf("write header junk: %v", err) } return nil } // Should be called with aSecMux RLocked func (jc *junkCreator) randomJunkWithSize(size int) ([]byte, error) { + // TODO: use a memory pool to allocate junk := make([]byte, size) _, err := jc.cha8Rand.Read(junk) return junk, err diff --git a/device/junk_creator_test.go b/device/junk_creator_test.go index d3cf2b3..33a7a23 100644 --- a/device/junk_creator_test.go +++ b/device/junk_creator_test.go @@ -91,13 +91,13 @@ func Test_junkCreator_randomPacketSize(t *testing.T) { } for range [30]struct{}{} { t.Run("", func(t *testing.T) { - if got := jc.randomPacketSize(); jc.device.aSecCfg.junkPacketMinSize > got || - got > jc.device.aSecCfg.junkPacketMaxSize { + if got := jc.randomPacketSize(); jc.device.awg.aSecCfg.junkPacketMinSize > got || + got > jc.device.awg.aSecCfg.junkPacketMaxSize { t.Errorf( "junkCreator.randomPacketSize() = %v, not between range [%v,%v]", got, - jc.device.aSecCfg.junkPacketMinSize, - jc.device.aSecCfg.junkPacketMaxSize, + jc.device.awg.aSecCfg.junkPacketMinSize, + jc.device.awg.aSecCfg.junkPacketMaxSize, ) } }) diff --git a/device/noise-protocol.go b/device/noise-protocol.go index 1289249..d774904 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -52,11 +52,18 @@ const ( WGLabelCookie = "cookie--" ) +const ( + DefaultMessageInitiationType uint32 = 1 + DefaultMessageResponseType uint32 = 2 + DefaultMessageCookieReplyType uint32 = 3 + DefaultMessageTransportType uint32 = 4 +) + var ( - MessageInitiationType uint32 = 1 - MessageResponseType uint32 = 2 - MessageCookieReplyType uint32 = 3 - MessageTransportType uint32 = 4 + MessageInitiationType uint32 = DefaultMessageInitiationType + MessageResponseType uint32 = DefaultMessageResponseType + MessageCookieReplyType uint32 = DefaultMessageCookieReplyType + MessageTransportType uint32 = DefaultMessageTransportType ) const ( @@ -197,12 +204,12 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e handshake.mixHash(handshake.remoteStatic[:]) - device.aSecMux.RLock() + device.awg.aSecMux.RLock() msg := MessageInitiation{ Type: MessageInitiationType, Ephemeral: handshake.localEphemeral.publicKey(), } - device.aSecMux.RUnlock() + device.awg.aSecMux.RUnlock() handshake.mixKey(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:]) @@ -256,12 +263,12 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { chainKey [blake2s.Size]byte ) - device.aSecMux.RLock() + device.awg.aSecMux.RLock() if msg.Type != MessageInitiationType { - device.aSecMux.RUnlock() + device.awg.aSecMux.RUnlock() return nil } - device.aSecMux.RUnlock() + device.awg.aSecMux.RUnlock() device.staticIdentity.RLock() defer device.staticIdentity.RUnlock() @@ -376,9 +383,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error } var msg MessageResponse - device.aSecMux.RLock() + device.awg.aSecMux.RLock() msg.Type = MessageResponseType - device.aSecMux.RUnlock() + device.awg.aSecMux.RUnlock() msg.Sender = handshake.localIndex msg.Receiver = handshake.remoteIndex @@ -428,12 +435,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error } func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { - device.aSecMux.RLock() + device.awg.aSecMux.RLock() if msg.Type != MessageResponseType { - device.aSecMux.RUnlock() + device.awg.aSecMux.RUnlock() return nil } - device.aSecMux.RUnlock() + device.awg.aSecMux.RUnlock() // lookup handshake by receiver diff --git a/device/peer.go b/device/peer.go index 5bc8ca4..bdc52fc 100644 --- a/device/peer.go +++ b/device/peer.go @@ -137,6 +137,7 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error { if err == nil { var totalLen uint64 for _, b := range buffers { + peer.device.awg.foo.PacketCounter++ totalLen += uint64(len(b)) } peer.txBytes.Add(totalLen) diff --git a/device/receive.go b/device/receive.go index 66c1a32..f235a33 100644 --- a/device/receive.go +++ b/device/receive.go @@ -129,7 +129,7 @@ func (device *Device) RoutineReceiveIncoming( } deathSpiral = 0 - device.aSecMux.RLock() + device.awg.aSecMux.RLock() // handle each packet in the batch for i, size := range sizes[:count] { if size < MinMessageSize { @@ -149,13 +149,14 @@ func (device *Device) RoutineReceiveIncoming( if msgType == assumedMsgType { packet = packet[junkSize:] } else { - device.log.Verbosef("Transport packet lined up with another msg type") + device.log.Verbosef("transport packet lined up with another msg type") msgType = binary.LittleEndian.Uint32(packet[:4]) } } else { msgType = binary.LittleEndian.Uint32(packet[:4]) if msgType != MessageTransportType { - device.log.Verbosef("ASec: Received message with unknown type") + // probably a junk packet + device.log.Verbosef("aSec: Received message with unknown type: %d", msgType) continue } } @@ -245,7 +246,7 @@ func (device *Device) RoutineReceiveIncoming( default: } } - device.aSecMux.RUnlock() + device.awg.aSecMux.RUnlock() for peer, elemsContainer := range elemsByPeer { if peer.isRunning.Load() { peer.queue.inbound.c <- elemsContainer @@ -304,7 +305,7 @@ func (device *Device) RoutineHandshake(id int) { for elem := range device.queue.handshake.c { - device.aSecMux.RLock() + device.awg.aSecMux.RLock() // handle cookie fields and ratelimiting @@ -456,7 +457,7 @@ func (device *Device) RoutineHandshake(id int) { peer.SendKeepalive() } skip: - device.aSecMux.RUnlock() + device.awg.aSecMux.RUnlock() device.PutMessageBuffer(elem.buffer) } } diff --git a/device/send.go b/device/send.go index 7eca099..f49b0e1 100644 --- a/device/send.go +++ b/device/send.go @@ -126,10 +126,21 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { var sendBuffer [][]byte // so only packet processed for cookie generation var junkedHeader []byte - if peer.device.isAdvancedSecurityOn() { - peer.device.aSecMux.RLock() - junks, err := peer.device.junkCreator.createJunkPackets() - peer.device.aSecMux.RUnlock() + + if peer.device.version >= VersionAwg { + junks := [][]byte{} + if peer.device.version == VersionAwgSpecialHandshake { + peer.device.awg.aSecMux.RLock() + // set junks depending on packet type + junks = peer.device.awg.handshakeHandler.GenerateSpecialJunk() + if junks == nil { + junks = peer.device.awg.handshakeHandler.GenerateSpecialJunk() + } + peer.device.awg.aSecMux.RUnlock() + } + peer.device.awg.aSecMux.RLock() + err := peer.device.awg.junkCreator.createJunkPackets(&junks) + peer.device.awg.aSecMux.RUnlock() if err != nil { peer.device.log.Errorf("%v - %v", peer, err) @@ -145,19 +156,19 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { } } - peer.device.aSecMux.RLock() - if peer.device.aSecCfg.initPacketJunkSize != 0 { - buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize) + peer.device.awg.aSecMux.RLock() + if peer.device.awg.aSecCfg.initPacketJunkSize != 0 { + buf := make([]byte, 0, peer.device.awg.aSecCfg.initPacketJunkSize) writer := bytes.NewBuffer(buf[:0]) - err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize) + err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.initPacketJunkSize) if err != nil { peer.device.log.Errorf("%v - %v", peer, err) - peer.device.aSecMux.RUnlock() + peer.device.awg.aSecMux.RUnlock() return err } junkedHeader = writer.Bytes() } - peer.device.aSecMux.RUnlock() + peer.device.awg.aSecMux.RUnlock() } var buf [MessageInitiationSize]byte @@ -195,19 +206,19 @@ func (peer *Peer) SendHandshakeResponse() error { } var junkedHeader []byte if peer.device.isAdvancedSecurityOn() { - peer.device.aSecMux.RLock() - if peer.device.aSecCfg.responsePacketJunkSize != 0 { - buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize) + peer.device.awg.aSecMux.RLock() + if peer.device.awg.aSecCfg.responsePacketJunkSize != 0 { + buf := make([]byte, 0, peer.device.awg.aSecCfg.responsePacketJunkSize) writer := bytes.NewBuffer(buf[:0]) - err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize) + err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.responsePacketJunkSize) if err != nil { - peer.device.aSecMux.RUnlock() + peer.device.awg.aSecMux.RUnlock() peer.device.log.Errorf("%v - %v", peer, err) return err } junkedHeader = writer.Bytes() } - peer.device.aSecMux.RUnlock() + peer.device.awg.aSecMux.RUnlock() } var buf [MessageResponseSize]byte writer := bytes.NewBuffer(buf[:0]) diff --git a/device/uapi.go b/device/uapi.go index 777bdda..7ae4b1c 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -18,6 +18,7 @@ import ( "sync" "time" + junktag "github.com/amnezia-vpn/amneziawg-go/device/internal/junk-tag" "github.com/amnezia-vpn/amneziawg-go/ipc" ) @@ -98,32 +99,32 @@ func (device *Device) IpcGetOperation(w io.Writer) error { } if device.isAdvancedSecurityOn() { - if device.aSecCfg.junkPacketCount != 0 { - sendf("jc=%d", device.aSecCfg.junkPacketCount) + if device.awg.aSecCfg.junkPacketCount != 0 { + sendf("jc=%d", device.awg.aSecCfg.junkPacketCount) } - if device.aSecCfg.junkPacketMinSize != 0 { - sendf("jmin=%d", device.aSecCfg.junkPacketMinSize) + if device.awg.aSecCfg.junkPacketMinSize != 0 { + sendf("jmin=%d", device.awg.aSecCfg.junkPacketMinSize) } - if device.aSecCfg.junkPacketMaxSize != 0 { - sendf("jmax=%d", device.aSecCfg.junkPacketMaxSize) + if device.awg.aSecCfg.junkPacketMaxSize != 0 { + sendf("jmax=%d", device.awg.aSecCfg.junkPacketMaxSize) } - if device.aSecCfg.initPacketJunkSize != 0 { - sendf("s1=%d", device.aSecCfg.initPacketJunkSize) + if device.awg.aSecCfg.initPacketJunkSize != 0 { + sendf("s1=%d", device.awg.aSecCfg.initPacketJunkSize) } - if device.aSecCfg.responsePacketJunkSize != 0 { - sendf("s2=%d", device.aSecCfg.responsePacketJunkSize) + if device.awg.aSecCfg.responsePacketJunkSize != 0 { + sendf("s2=%d", device.awg.aSecCfg.responsePacketJunkSize) } - if device.aSecCfg.initPacketMagicHeader != 0 { - sendf("h1=%d", device.aSecCfg.initPacketMagicHeader) + if device.awg.aSecCfg.initPacketMagicHeader != 0 { + sendf("h1=%d", device.awg.aSecCfg.initPacketMagicHeader) } - if device.aSecCfg.responsePacketMagicHeader != 0 { - sendf("h2=%d", device.aSecCfg.responsePacketMagicHeader) + if device.awg.aSecCfg.responsePacketMagicHeader != 0 { + sendf("h2=%d", device.awg.aSecCfg.responsePacketMagicHeader) } - if device.aSecCfg.underloadPacketMagicHeader != 0 { - sendf("h3=%d", device.aSecCfg.underloadPacketMagicHeader) + if device.awg.aSecCfg.underloadPacketMagicHeader != 0 { + sendf("h3=%d", device.awg.aSecCfg.underloadPacketMagicHeader) } - if device.aSecCfg.transportPacketMagicHeader != 0 { - sendf("h4=%d", device.aSecCfg.transportPacketMagicHeader) + if device.awg.aSecCfg.transportPacketMagicHeader != 0 { + sendf("h4=%d", device.awg.aSecCfg.transportPacketMagicHeader) } } @@ -180,13 +181,13 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { peer := new(ipcSetPeer) deviceConfig := true - tempASecCfg := aSecCfgType{} + tempAwg := awg{} scanner := bufio.NewScanner(r) for scanner.Scan() { line := scanner.Text() if line == "" { // Blank line means terminate operation. - err := device.handlePostConfig(&tempASecCfg) + err := device.handlePostConfig(&tempAwg) if err != nil { return err } @@ -217,7 +218,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { var err error if deviceConfig { - err = device.handleDeviceLine(key, value, &tempASecCfg) + err = device.handleDeviceLine(key, value, &tempAwg) } else { err = device.handlePeerLine(peer, key, value) } @@ -225,7 +226,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { return err } } - err = device.handlePostConfig(&tempASecCfg) + err = device.handlePostConfig(&tempAwg) if err != nil { return err } @@ -237,7 +238,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { return nil } -func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgType) error { +func (device *Device) handleDeviceLine(key, value string, tempAwg *awg) error { switch key { case "private_key": var sk NoisePrivateKey @@ -286,80 +287,113 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy case "jc": junkPacketCount, err := strconv.Atoi(value) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_count %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_count %w", err) } device.log.Verbosef("UAPI: Updating junk_packet_count") - tempASecCfg.junkPacketCount = junkPacketCount - tempASecCfg.isSet = true + tempAwg.aSecCfg.junkPacketCount = junkPacketCount + tempAwg.aSecCfg.isSet = true case "jmin": junkPacketMinSize, err := strconv.Atoi(value) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_min_size %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_min_size %w", err) } device.log.Verbosef("UAPI: Updating junk_packet_min_size") - tempASecCfg.junkPacketMinSize = junkPacketMinSize - tempASecCfg.isSet = true + tempAwg.aSecCfg.junkPacketMinSize = junkPacketMinSize + tempAwg.aSecCfg.isSet = true case "jmax": junkPacketMaxSize, err := strconv.Atoi(value) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_max_size %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_max_size %w", err) } device.log.Verbosef("UAPI: Updating junk_packet_max_size") - tempASecCfg.junkPacketMaxSize = junkPacketMaxSize - tempASecCfg.isSet = true + tempAwg.aSecCfg.junkPacketMaxSize = junkPacketMaxSize + tempAwg.aSecCfg.isSet = true case "s1": initPacketJunkSize, err := strconv.Atoi(value) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_junk_size %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_junk_size %w", err) } device.log.Verbosef("UAPI: Updating init_packet_junk_size") - tempASecCfg.initPacketJunkSize = initPacketJunkSize - tempASecCfg.isSet = true + tempAwg.aSecCfg.initPacketJunkSize = initPacketJunkSize + tempAwg.aSecCfg.isSet = true case "s2": responsePacketJunkSize, err := strconv.Atoi(value) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_junk_size %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_junk_size %w", err) } device.log.Verbosef("UAPI: Updating response_packet_junk_size") - tempASecCfg.responsePacketJunkSize = responsePacketJunkSize - tempASecCfg.isSet = true + tempAwg.aSecCfg.responsePacketJunkSize = responsePacketJunkSize + tempAwg.aSecCfg.isSet = true case "h1": initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_magic_header %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_magic_header %w", err) } - tempASecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader) - tempASecCfg.isSet = true + tempAwg.aSecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader) + tempAwg.aSecCfg.isSet = true case "h2": responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_magic_header %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_magic_header %w", err) } - tempASecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader) - tempASecCfg.isSet = true + tempAwg.aSecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader) + tempAwg.aSecCfg.isSet = true case "h3": underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse underload_packet_magic_header %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse underload_packet_magic_header %w", err) } - tempASecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader) - tempASecCfg.isSet = true + tempAwg.aSecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader) + tempAwg.aSecCfg.isSet = true case "h4": transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse transport_packet_magic_header %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_magic_header %w", err) + } + tempAwg.aSecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader) + tempAwg.aSecCfg.isSet = true + case "i1", "i2", "i3", "i4", "i5": + if len(value) == 0 { + return ipcErrorf(ipc.IpcErrorInvalid, "%s should be non null", key) } - tempASecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader) - tempASecCfg.isSet = true + generators, err := junktag.Parse(key, value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err) + } + device.log.Verbosef("UAPI: Updating %s", key) + tempAwg.handshakeHandler.SpecialJunk.AppendGenerator(generators) + tempAwg.aSecCfg.isSet = true + case "j1", "j2", "j3": + if len(value) == 0 { + return ipcErrorf(ipc.IpcErrorInvalid, "%s should be non null", key) + } + + generators, err := junktag.Parse(key, value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err) + } + device.log.Verbosef("UAPI: Updating %s", key) + + tempAwg.handshakeHandler.ControlledJunk.AppendGenerator(generators) + tempAwg.aSecCfg.isSet = true + case "itime": + itime, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "parse itime %w", err) + } + device.log.Verbosef("UAPI: Updating itime %s", itime) + + tempAwg.handshakeHandler.ITimeout = time.Duration(itime) + tempAwg.aSecCfg.isSet = true default: return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key) }